1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-24 20:14:31 +02:00

[suggestion] Add helper interface for ProxyBalancer interface (#2316)

* [suggestion] Add helper interface for ProxyBalancer interface

* Update proxy_test.go

* addressed code review comments

* address pr comments

* clean up

* return error
This commit is contained in:
Hristo Hristov 2022-10-29 21:54:23 +03:00 committed by GitHub
parent 8f2bf82982
commit 0ce73028d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 66 additions and 3 deletions

View File

@ -72,6 +72,11 @@ type (
Next(echo.Context) *ProxyTarget Next(echo.Context) *ProxyTarget
} }
// TargetProvider defines an interface that gives the opportunity for balancer to return custom errors when selecting target.
TargetProvider interface {
NextTarget(echo.Context) (*ProxyTarget, error)
}
commonBalancer struct { commonBalancer struct {
targets []*ProxyTarget targets []*ProxyTarget
mutex sync.RWMutex mutex sync.RWMutex
@ -223,6 +228,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
} }
} }
provider, isTargetProvider := config.Balancer.(TargetProvider)
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) { return func(c echo.Context) (err error) {
if config.Skipper(c) { if config.Skipper(c) {
@ -231,7 +237,16 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
req := c.Request() req := c.Request()
res := c.Response() res := c.Response()
tgt := config.Balancer.Next(c)
var tgt *ProxyTarget
if isTargetProvider {
tgt, err = provider.NextTarget(c)
if err != nil {
return err
}
} else {
tgt = config.Balancer.Next(c)
}
c.Set(config.ContextKey, tgt) c.Set(config.ContextKey, tgt)
if err := rewriteURL(config.RegexRewrite, req); err != nil { if err := rewriteURL(config.RegexRewrite, req); err != nil {

View File

@ -18,7 +18,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
//Assert expected with url.EscapedPath method to obtain the path. // Assert expected with url.EscapedPath method to obtain the path.
func TestProxy(t *testing.T) { func TestProxy(t *testing.T) {
// Setup // Setup
t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -31,7 +31,6 @@ func TestProxy(t *testing.T) {
})) }))
defer t2.Close() defer t2.Close()
url2, _ := url.Parse(t2.URL) url2, _ := url.Parse(t2.URL)
targets := []*ProxyTarget{ targets := []*ProxyTarget{
{ {
Name: "target 1", Name: "target 1",
@ -122,6 +121,55 @@ func TestProxy(t *testing.T) {
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
} }
type testProvider struct {
*commonBalancer
target *ProxyTarget
err error
}
func (p *testProvider) Next(c echo.Context) *ProxyTarget {
return &ProxyTarget{}
}
func (p *testProvider) NextTarget(c echo.Context) (*ProxyTarget, error) {
return p.target, p.err
}
func TestTargetProvider(t *testing.T) {
t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "target 1")
}))
defer t1.Close()
url1, _ := url.Parse(t1.URL)
e := echo.New()
tp := &testProvider{commonBalancer: new(commonBalancer)}
tp.target = &ProxyTarget{Name: "target 1", URL: url1}
e.Use(Proxy(tp))
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
e.ServeHTTP(rec, req)
body := rec.Body.String()
assert.Equal(t, "target 1", body)
}
func TestFailNextTarget(t *testing.T) {
url1, err := url.Parse("http://dummy:8080")
assert.Nil(t, err)
e := echo.New()
tp := &testProvider{commonBalancer: new(commonBalancer)}
tp.target = &ProxyTarget{Name: "target 1", URL: url1}
tp.err = echo.NewHTTPError(http.StatusInternalServerError, "method could not select target")
e.Use(Proxy(tp))
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
e.ServeHTTP(rec, req)
body := rec.Body.String()
assert.Equal(t, "{\"message\":\"method could not select target\"}\n", body)
}
func TestProxyRealIPHeader(t *testing.T) { func TestProxyRealIPHeader(t *testing.T) {
// Setup // Setup
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))