1
0
mirror of https://github.com/labstack/echo.git synced 2024-11-24 08:22:21 +02:00

Support retries of failed proxy requests (#2414)

Support retries of failed proxy requests
This commit is contained in:
mikemherron 2023-05-12 18:36:24 +01:00 committed by GitHub
parent deb17d2388
commit 0ae74648b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 437 additions and 41 deletions

View File

@ -29,6 +29,33 @@ type (
// Required.
Balancer ProxyBalancer
// RetryCount defines the number of times a failed proxied request should be retried
// using the next available ProxyTarget. Defaults to 0, meaning requests are never retried.
RetryCount int
// RetryFilter defines a function used to determine if a failed request to a
// ProxyTarget should be retried. The RetryFilter will only be called when the number
// of previous retries is less than RetryCount. If the function returns true, the
// request will be retried. The provided error indicates the reason for the request
// failure. When the ProxyTarget is unavailable, the error will be an instance of
// echo.HTTPError with a Code of http.StatusBadGateway. In all other cases, the error
// will indicate an internal error in the Proxy middleware. When a RetryFilter is not
// specified, all requests that fail with http.StatusBadGateway will be retried. A custom
// RetryFilter can be provided to only retry specific requests. Note that RetryFilter is
// only called when the request to the target fails, or an internal error in the Proxy
// middleware has occurred. Successful requests that return a non-200 response code cannot
// be retried.
RetryFilter func(c echo.Context, e error) bool
// ErrorHandler defines a function which can be used to return custom errors from
// the Proxy middleware. ErrorHandler is only invoked when there has been
// either an internal error in the Proxy middleware or the ProxyTarget is
// unavailable. Due to the way requests are proxied, ErrorHandler is not invoked
// when a ProxyTarget returns a non-200 response. In these cases, the response
// is already written so errors cannot be modified. ErrorHandler is only
// invoked after all retry attempts have been exhausted.
ErrorHandler func(c echo.Context, err error) error
// Rewrite defines URL path rewrite rules. The values captured in asterisk can be
// retrieved by index e.g. $1, $2 and so on.
// Examples:
@ -71,7 +98,8 @@ type (
Next(echo.Context) *ProxyTarget
}
// TargetProvider defines an interface that gives the opportunity for balancer to return custom errors when selecting target.
// TargetProvider defines an interface that gives the opportunity for balancer
// to return custom errors when selecting target.
TargetProvider interface {
NextTarget(echo.Context) (*ProxyTarget, error)
}
@ -107,14 +135,14 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
in, _, err := c.Response().Hijack()
if err != nil {
c.Set("_error", fmt.Sprintf("proxy raw, hijack error=%v, url=%s", t.URL, err))
c.Set("_error", fmt.Errorf("proxy raw, hijack error=%w, url=%s", err, t.URL))
return
}
defer in.Close()
out, err := net.Dial("tcp", t.URL.Host)
if err != nil {
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", t.URL, err)))
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL)))
return
}
defer out.Close()
@ -122,7 +150,7 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
// Write header
err = r.Write(out)
if err != nil {
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", t.URL, err)))
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", err, t.URL)))
return
}
@ -136,7 +164,7 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
go cp(in, out)
err = <-errCh
if err != nil && err != io.EOF {
c.Set("_error", fmt.Errorf("proxy raw, copy body error=%v, url=%s", t.URL, err))
c.Set("_error", fmt.Errorf("proxy raw, copy body error=%w, url=%s", err, t.URL))
}
})
}
@ -200,7 +228,12 @@ func (b *randomBalancer) Next(c echo.Context) *ProxyTarget {
return b.targets[b.random.Intn(len(b.targets))]
}
// Next returns an upstream target using round-robin technique.
// Next returns an upstream target using round-robin technique. In the case
// where a previously failed request is being retried, the round-robin
// balancer will attempt to use the next target relative to the original
// request. If the list of targets held by the balancer is modified while a
// failed request is being retried, it is possible that the balancer will
// return the original failed target.
//
// Note: `nil` is returned in case upstream target list is empty.
func (b *roundRobinBalancer) Next(c echo.Context) *ProxyTarget {
@ -211,13 +244,29 @@ func (b *roundRobinBalancer) Next(c echo.Context) *ProxyTarget {
} else if len(b.targets) == 1 {
return b.targets[0]
}
// reset the index if out of bounds
if b.i >= len(b.targets) {
b.i = 0
var i int
const lastIdxKey = "_round_robin_last_index"
// This request is a retry, start from the index of the previous
// target to ensure we don't attempt to retry the request with
// the same failed target
if c.Get(lastIdxKey) != nil {
i = c.Get(lastIdxKey).(int)
i++
if i >= len(b.targets) {
i = 0
}
} else {
// This is a first time request, use the global index
if b.i >= len(b.targets) {
b.i = 0
}
i = b.i
b.i++
}
t := b.targets[b.i]
b.i++
return t
c.Set(lastIdxKey, i)
return b.targets[i]
}
// Proxy returns a Proxy middleware.
@ -232,14 +281,26 @@ func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc {
// ProxyWithConfig returns a Proxy middleware with config.
// See: `Proxy()`
func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
if config.Balancer == nil {
panic("echo: proxy middleware requires balancer")
}
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultProxyConfig.Skipper
}
if config.Balancer == nil {
panic("echo: proxy middleware requires balancer")
if config.RetryFilter == nil {
config.RetryFilter = func(c echo.Context, e error) bool {
if httpErr, ok := e.(*echo.HTTPError); ok {
return httpErr.Code == http.StatusBadGateway
}
return false
}
}
if config.ErrorHandler == nil {
config.ErrorHandler = func(c echo.Context, err error) error {
return err
}
}
if config.Rewrite != nil {
if config.RegexRewrite == nil {
config.RegexRewrite = make(map[*regexp.Regexp]string)
@ -250,28 +311,17 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
}
provider, isTargetProvider := config.Balancer.(TargetProvider)
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
res := c.Response()
var tgt *ProxyTarget
if isTargetProvider {
tgt, err = provider.NextTarget(c)
if err != nil {
return err
}
} else {
tgt = config.Balancer.Next(c)
}
c.Set(config.ContextKey, tgt)
if err := rewriteURL(config.RegexRewrite, req); err != nil {
return err
return config.ErrorHandler(c, err)
}
// Fix header
@ -287,19 +337,49 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
req.Header.Set(echo.HeaderXForwardedFor, c.RealIP())
}
// Proxy
switch {
case c.IsWebSocket():
proxyRaw(tgt, c).ServeHTTP(res, req)
case req.Header.Get(echo.HeaderAccept) == "text/event-stream":
default:
proxyHTTP(tgt, c, config).ServeHTTP(res, req)
}
if e, ok := c.Get("_error").(error); ok {
err = e
}
retries := config.RetryCount
for {
var tgt *ProxyTarget
var err error
if isTargetProvider {
tgt, err = provider.NextTarget(c)
if err != nil {
return config.ErrorHandler(c, err)
}
} else {
tgt = config.Balancer.Next(c)
}
return
c.Set(config.ContextKey, tgt)
//If retrying a failed request, clear any previous errors from
//context here so that balancers have the option to check for
//errors that occurred using previous target
if retries < config.RetryCount {
c.Set("_error", nil)
}
// Proxy
switch {
case c.IsWebSocket():
proxyRaw(tgt, c).ServeHTTP(res, req)
case req.Header.Get(echo.HeaderAccept) == "text/event-stream":
default:
proxyHTTP(tgt, c, config).ServeHTTP(res, req)
}
err, hasError := c.Get("_error").(error)
if !hasError {
return nil
}
retry := retries > 0 && config.RetryFilter(c, err)
if !retry {
return config.ErrorHandler(c, err)
}
retries--
}
}
}
}

View File

@ -3,6 +3,7 @@ package middleware
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
@ -393,6 +394,321 @@ func TestProxyError(t *testing.T) {
assert.Equal(t, http.StatusBadGateway, rec.Code)
}
func TestProxyRetries(t *testing.T) {
newServer := func(res int) (*url.URL, *httptest.Server) {
server := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(res)
}),
)
targetURL, _ := url.Parse(server.URL)
return targetURL, server
}
targetURL, server := newServer(http.StatusOK)
defer server.Close()
goodTarget := &ProxyTarget{
Name: "Good",
URL: targetURL,
}
targetURL, server = newServer(http.StatusBadRequest)
defer server.Close()
goodTargetWith40X := &ProxyTarget{
Name: "Good with 40X",
URL: targetURL,
}
targetURL, _ = url.Parse("http://127.0.0.1:27121")
badTarget := &ProxyTarget{
Name: "Bad",
URL: targetURL,
}
alwaysRetryFilter := func(c echo.Context, e error) bool { return true }
neverRetryFilter := func(c echo.Context, e error) bool { return false }
testCases := []struct {
name string
retryCount int
retryFilters []func(c echo.Context, e error) bool
targets []*ProxyTarget
expectedResponse int
}{
{
name: "retry count 0 does not attempt retry on fail",
targets: []*ProxyTarget{
badTarget,
goodTarget,
},
expectedResponse: http.StatusBadGateway,
},
{
name: "retry count 1 does not attempt retry on success",
retryCount: 1,
targets: []*ProxyTarget{
goodTarget,
},
expectedResponse: http.StatusOK,
},
{
name: "retry count 1 does retry on handler return true",
retryCount: 1,
retryFilters: []func(c echo.Context, e error) bool{
alwaysRetryFilter,
},
targets: []*ProxyTarget{
badTarget,
goodTarget,
},
expectedResponse: http.StatusOK,
},
{
name: "retry count 1 does not retry on handler return false",
retryCount: 1,
retryFilters: []func(c echo.Context, e error) bool{
neverRetryFilter,
},
targets: []*ProxyTarget{
badTarget,
goodTarget,
},
expectedResponse: http.StatusBadGateway,
},
{
name: "retry count 2 returns error when no more retries left",
retryCount: 2,
retryFilters: []func(c echo.Context, e error) bool{
alwaysRetryFilter,
alwaysRetryFilter,
},
targets: []*ProxyTarget{
badTarget,
badTarget,
badTarget,
goodTarget, //Should never be reached as only 2 retries
},
expectedResponse: http.StatusBadGateway,
},
{
name: "retry count 2 returns error when retries left but handler returns false",
retryCount: 3,
retryFilters: []func(c echo.Context, e error) bool{
alwaysRetryFilter,
alwaysRetryFilter,
neverRetryFilter,
},
targets: []*ProxyTarget{
badTarget,
badTarget,
badTarget,
goodTarget, //Should never be reached as retry handler returns false on 2nd check
},
expectedResponse: http.StatusBadGateway,
},
{
name: "retry count 3 succeeds",
retryCount: 3,
retryFilters: []func(c echo.Context, e error) bool{
alwaysRetryFilter,
alwaysRetryFilter,
alwaysRetryFilter,
},
targets: []*ProxyTarget{
badTarget,
badTarget,
badTarget,
goodTarget,
},
expectedResponse: http.StatusOK,
},
{
name: "40x responses are not retried",
retryCount: 1,
targets: []*ProxyTarget{
goodTargetWith40X,
goodTarget,
},
expectedResponse: http.StatusBadRequest,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
retryFilterCall := 0
retryFilter := func(c echo.Context, e error) bool {
if len(tc.retryFilters) == 0 {
assert.FailNow(t, fmt.Sprintf("unexpected calls, %d, to retry handler", retryFilterCall))
}
retryFilterCall++
nextRetryFilter := tc.retryFilters[0]
tc.retryFilters = tc.retryFilters[1:]
return nextRetryFilter(c, e)
}
e := echo.New()
e.Use(ProxyWithConfig(
ProxyConfig{
Balancer: NewRoundRobinBalancer(tc.targets),
RetryCount: tc.retryCount,
RetryFilter: retryFilter,
},
))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectedResponse, rec.Code)
if len(tc.retryFilters) > 0 {
assert.FailNow(t, fmt.Sprintf("expected %d more retry handler calls", len(tc.retryFilters)))
}
})
}
}
func TestProxyRetryWithBackendTimeout(t *testing.T) {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.ResponseHeaderTimeout = time.Millisecond * 500
timeoutBackend := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(1 * time.Second)
w.WriteHeader(404)
}),
)
defer timeoutBackend.Close()
timeoutTargetURL, _ := url.Parse(timeoutBackend.URL)
goodBackend := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}),
)
defer goodBackend.Close()
goodTargetURL, _ := url.Parse(goodBackend.URL)
e := echo.New()
e.Use(ProxyWithConfig(
ProxyConfig{
Transport: transport,
Balancer: NewRoundRobinBalancer([]*ProxyTarget{
{
Name: "Timeout",
URL: timeoutTargetURL,
},
{
Name: "Good",
URL: goodTargetURL,
},
}),
RetryCount: 1,
},
))
var wg sync.WaitGroup
for i := 0; i < 20; i++ {
wg.Add(1)
go func() {
defer wg.Done()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, 200, rec.Code)
}()
}
wg.Wait()
}
func TestProxyErrorHandler(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
goodURL, _ := url.Parse(server.URL)
defer server.Close()
goodTarget := &ProxyTarget{
Name: "Good",
URL: goodURL,
}
badURL, _ := url.Parse("http://127.0.0.1:27121")
badTarget := &ProxyTarget{
Name: "Bad",
URL: badURL,
}
transformedError := errors.New("a new error")
testCases := []struct {
name string
target *ProxyTarget
errorHandler func(c echo.Context, e error) error
expectFinalError func(t *testing.T, err error)
}{
{
name: "Error handler not invoked when request success",
target: goodTarget,
errorHandler: func(c echo.Context, e error) error {
assert.FailNow(t, "error handler should not be invoked")
return e
},
},
{
name: "Error handler invoked when request fails",
target: badTarget,
errorHandler: func(c echo.Context, e error) error {
httpErr, ok := e.(*echo.HTTPError)
assert.True(t, ok, "expected http error to be passed to handler")
assert.Equal(t, http.StatusBadGateway, httpErr.Code, "expected http bad gateway error to be passed to handler")
return transformedError
},
expectFinalError: func(t *testing.T, err error) {
assert.Equal(t, transformedError, err, "transformed error not returned from proxy")
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.Use(ProxyWithConfig(
ProxyConfig{
Balancer: NewRoundRobinBalancer([]*ProxyTarget{tc.target}),
ErrorHandler: tc.errorHandler,
},
))
errorHandlerCalled := false
e.HTTPErrorHandler = func(err error, c echo.Context) {
errorHandlerCalled = true
tc.expectFinalError(t, err)
e.DefaultHTTPErrorHandler(err, c)
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
if !errorHandlerCalled && tc.expectFinalError != nil {
t.Fatalf("error handler was not called")
}
})
}
}
func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) {
var timeoutStop sync.WaitGroup
timeoutStop.Add(1)