mirror of
https://github.com/labstack/echo.git
synced 2024-11-24 08:22:21 +02:00
Support retries of failed proxy requests (#2414)
Support retries of failed proxy requests
This commit is contained in:
parent
deb17d2388
commit
0ae74648b9
@ -29,6 +29,33 @@ type (
|
||||
// Required.
|
||||
Balancer ProxyBalancer
|
||||
|
||||
// RetryCount defines the number of times a failed proxied request should be retried
|
||||
// using the next available ProxyTarget. Defaults to 0, meaning requests are never retried.
|
||||
RetryCount int
|
||||
|
||||
// RetryFilter defines a function used to determine if a failed request to a
|
||||
// ProxyTarget should be retried. The RetryFilter will only be called when the number
|
||||
// of previous retries is less than RetryCount. If the function returns true, the
|
||||
// request will be retried. The provided error indicates the reason for the request
|
||||
// failure. When the ProxyTarget is unavailable, the error will be an instance of
|
||||
// echo.HTTPError with a Code of http.StatusBadGateway. In all other cases, the error
|
||||
// will indicate an internal error in the Proxy middleware. When a RetryFilter is not
|
||||
// specified, all requests that fail with http.StatusBadGateway will be retried. A custom
|
||||
// RetryFilter can be provided to only retry specific requests. Note that RetryFilter is
|
||||
// only called when the request to the target fails, or an internal error in the Proxy
|
||||
// middleware has occurred. Successful requests that return a non-200 response code cannot
|
||||
// be retried.
|
||||
RetryFilter func(c echo.Context, e error) bool
|
||||
|
||||
// ErrorHandler defines a function which can be used to return custom errors from
|
||||
// the Proxy middleware. ErrorHandler is only invoked when there has been
|
||||
// either an internal error in the Proxy middleware or the ProxyTarget is
|
||||
// unavailable. Due to the way requests are proxied, ErrorHandler is not invoked
|
||||
// when a ProxyTarget returns a non-200 response. In these cases, the response
|
||||
// is already written so errors cannot be modified. ErrorHandler is only
|
||||
// invoked after all retry attempts have been exhausted.
|
||||
ErrorHandler func(c echo.Context, err error) error
|
||||
|
||||
// Rewrite defines URL path rewrite rules. The values captured in asterisk can be
|
||||
// retrieved by index e.g. $1, $2 and so on.
|
||||
// Examples:
|
||||
@ -71,7 +98,8 @@ type (
|
||||
Next(echo.Context) *ProxyTarget
|
||||
}
|
||||
|
||||
// TargetProvider defines an interface that gives the opportunity for balancer to return custom errors when selecting target.
|
||||
// TargetProvider defines an interface that gives the opportunity for balancer
|
||||
// to return custom errors when selecting target.
|
||||
TargetProvider interface {
|
||||
NextTarget(echo.Context) (*ProxyTarget, error)
|
||||
}
|
||||
@ -107,14 +135,14 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
in, _, err := c.Response().Hijack()
|
||||
if err != nil {
|
||||
c.Set("_error", fmt.Sprintf("proxy raw, hijack error=%v, url=%s", t.URL, err))
|
||||
c.Set("_error", fmt.Errorf("proxy raw, hijack error=%w, url=%s", err, t.URL))
|
||||
return
|
||||
}
|
||||
defer in.Close()
|
||||
|
||||
out, err := net.Dial("tcp", t.URL.Host)
|
||||
if err != nil {
|
||||
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", t.URL, err)))
|
||||
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL)))
|
||||
return
|
||||
}
|
||||
defer out.Close()
|
||||
@ -122,7 +150,7 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
|
||||
// Write header
|
||||
err = r.Write(out)
|
||||
if err != nil {
|
||||
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", t.URL, err)))
|
||||
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", err, t.URL)))
|
||||
return
|
||||
}
|
||||
|
||||
@ -136,7 +164,7 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
|
||||
go cp(in, out)
|
||||
err = <-errCh
|
||||
if err != nil && err != io.EOF {
|
||||
c.Set("_error", fmt.Errorf("proxy raw, copy body error=%v, url=%s", t.URL, err))
|
||||
c.Set("_error", fmt.Errorf("proxy raw, copy body error=%w, url=%s", err, t.URL))
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -200,7 +228,12 @@ func (b *randomBalancer) Next(c echo.Context) *ProxyTarget {
|
||||
return b.targets[b.random.Intn(len(b.targets))]
|
||||
}
|
||||
|
||||
// Next returns an upstream target using round-robin technique.
|
||||
// Next returns an upstream target using round-robin technique. In the case
|
||||
// where a previously failed request is being retried, the round-robin
|
||||
// balancer will attempt to use the next target relative to the original
|
||||
// request. If the list of targets held by the balancer is modified while a
|
||||
// failed request is being retried, it is possible that the balancer will
|
||||
// return the original failed target.
|
||||
//
|
||||
// Note: `nil` is returned in case upstream target list is empty.
|
||||
func (b *roundRobinBalancer) Next(c echo.Context) *ProxyTarget {
|
||||
@ -211,13 +244,29 @@ func (b *roundRobinBalancer) Next(c echo.Context) *ProxyTarget {
|
||||
} else if len(b.targets) == 1 {
|
||||
return b.targets[0]
|
||||
}
|
||||
// reset the index if out of bounds
|
||||
if b.i >= len(b.targets) {
|
||||
b.i = 0
|
||||
|
||||
var i int
|
||||
const lastIdxKey = "_round_robin_last_index"
|
||||
// This request is a retry, start from the index of the previous
|
||||
// target to ensure we don't attempt to retry the request with
|
||||
// the same failed target
|
||||
if c.Get(lastIdxKey) != nil {
|
||||
i = c.Get(lastIdxKey).(int)
|
||||
i++
|
||||
if i >= len(b.targets) {
|
||||
i = 0
|
||||
}
|
||||
} else {
|
||||
// This is a first time request, use the global index
|
||||
if b.i >= len(b.targets) {
|
||||
b.i = 0
|
||||
}
|
||||
i = b.i
|
||||
b.i++
|
||||
}
|
||||
t := b.targets[b.i]
|
||||
b.i++
|
||||
return t
|
||||
|
||||
c.Set(lastIdxKey, i)
|
||||
return b.targets[i]
|
||||
}
|
||||
|
||||
// Proxy returns a Proxy middleware.
|
||||
@ -232,14 +281,26 @@ func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc {
|
||||
// ProxyWithConfig returns a Proxy middleware with config.
|
||||
// See: `Proxy()`
|
||||
func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
|
||||
if config.Balancer == nil {
|
||||
panic("echo: proxy middleware requires balancer")
|
||||
}
|
||||
// Defaults
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultProxyConfig.Skipper
|
||||
}
|
||||
if config.Balancer == nil {
|
||||
panic("echo: proxy middleware requires balancer")
|
||||
if config.RetryFilter == nil {
|
||||
config.RetryFilter = func(c echo.Context, e error) bool {
|
||||
if httpErr, ok := e.(*echo.HTTPError); ok {
|
||||
return httpErr.Code == http.StatusBadGateway
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
if config.ErrorHandler == nil {
|
||||
config.ErrorHandler = func(c echo.Context, err error) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if config.Rewrite != nil {
|
||||
if config.RegexRewrite == nil {
|
||||
config.RegexRewrite = make(map[*regexp.Regexp]string)
|
||||
@ -250,28 +311,17 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
|
||||
}
|
||||
|
||||
provider, isTargetProvider := config.Balancer.(TargetProvider)
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) (err error) {
|
||||
return func(c echo.Context) error {
|
||||
if config.Skipper(c) {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
req := c.Request()
|
||||
res := c.Response()
|
||||
|
||||
var tgt *ProxyTarget
|
||||
if isTargetProvider {
|
||||
tgt, err = provider.NextTarget(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
tgt = config.Balancer.Next(c)
|
||||
}
|
||||
c.Set(config.ContextKey, tgt)
|
||||
|
||||
if err := rewriteURL(config.RegexRewrite, req); err != nil {
|
||||
return err
|
||||
return config.ErrorHandler(c, err)
|
||||
}
|
||||
|
||||
// Fix header
|
||||
@ -287,19 +337,49 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
|
||||
req.Header.Set(echo.HeaderXForwardedFor, c.RealIP())
|
||||
}
|
||||
|
||||
// Proxy
|
||||
switch {
|
||||
case c.IsWebSocket():
|
||||
proxyRaw(tgt, c).ServeHTTP(res, req)
|
||||
case req.Header.Get(echo.HeaderAccept) == "text/event-stream":
|
||||
default:
|
||||
proxyHTTP(tgt, c, config).ServeHTTP(res, req)
|
||||
}
|
||||
if e, ok := c.Get("_error").(error); ok {
|
||||
err = e
|
||||
}
|
||||
retries := config.RetryCount
|
||||
for {
|
||||
var tgt *ProxyTarget
|
||||
var err error
|
||||
if isTargetProvider {
|
||||
tgt, err = provider.NextTarget(c)
|
||||
if err != nil {
|
||||
return config.ErrorHandler(c, err)
|
||||
}
|
||||
} else {
|
||||
tgt = config.Balancer.Next(c)
|
||||
}
|
||||
|
||||
return
|
||||
c.Set(config.ContextKey, tgt)
|
||||
|
||||
//If retrying a failed request, clear any previous errors from
|
||||
//context here so that balancers have the option to check for
|
||||
//errors that occurred using previous target
|
||||
if retries < config.RetryCount {
|
||||
c.Set("_error", nil)
|
||||
}
|
||||
|
||||
// Proxy
|
||||
switch {
|
||||
case c.IsWebSocket():
|
||||
proxyRaw(tgt, c).ServeHTTP(res, req)
|
||||
case req.Header.Get(echo.HeaderAccept) == "text/event-stream":
|
||||
default:
|
||||
proxyHTTP(tgt, c, config).ServeHTTP(res, req)
|
||||
}
|
||||
|
||||
err, hasError := c.Get("_error").(error)
|
||||
if !hasError {
|
||||
return nil
|
||||
}
|
||||
|
||||
retry := retries > 0 && config.RetryFilter(c, err)
|
||||
if !retry {
|
||||
return config.ErrorHandler(c, err)
|
||||
}
|
||||
|
||||
retries--
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@ -393,6 +394,321 @@ func TestProxyError(t *testing.T) {
|
||||
assert.Equal(t, http.StatusBadGateway, rec.Code)
|
||||
}
|
||||
|
||||
func TestProxyRetries(t *testing.T) {
|
||||
|
||||
newServer := func(res int) (*url.URL, *httptest.Server) {
|
||||
server := httptest.NewServer(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(res)
|
||||
}),
|
||||
)
|
||||
targetURL, _ := url.Parse(server.URL)
|
||||
return targetURL, server
|
||||
}
|
||||
|
||||
targetURL, server := newServer(http.StatusOK)
|
||||
defer server.Close()
|
||||
goodTarget := &ProxyTarget{
|
||||
Name: "Good",
|
||||
URL: targetURL,
|
||||
}
|
||||
|
||||
targetURL, server = newServer(http.StatusBadRequest)
|
||||
defer server.Close()
|
||||
goodTargetWith40X := &ProxyTarget{
|
||||
Name: "Good with 40X",
|
||||
URL: targetURL,
|
||||
}
|
||||
|
||||
targetURL, _ = url.Parse("http://127.0.0.1:27121")
|
||||
badTarget := &ProxyTarget{
|
||||
Name: "Bad",
|
||||
URL: targetURL,
|
||||
}
|
||||
|
||||
alwaysRetryFilter := func(c echo.Context, e error) bool { return true }
|
||||
neverRetryFilter := func(c echo.Context, e error) bool { return false }
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
retryCount int
|
||||
retryFilters []func(c echo.Context, e error) bool
|
||||
targets []*ProxyTarget
|
||||
expectedResponse int
|
||||
}{
|
||||
{
|
||||
name: "retry count 0 does not attempt retry on fail",
|
||||
targets: []*ProxyTarget{
|
||||
badTarget,
|
||||
goodTarget,
|
||||
},
|
||||
expectedResponse: http.StatusBadGateway,
|
||||
},
|
||||
{
|
||||
name: "retry count 1 does not attempt retry on success",
|
||||
retryCount: 1,
|
||||
targets: []*ProxyTarget{
|
||||
goodTarget,
|
||||
},
|
||||
expectedResponse: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "retry count 1 does retry on handler return true",
|
||||
retryCount: 1,
|
||||
retryFilters: []func(c echo.Context, e error) bool{
|
||||
alwaysRetryFilter,
|
||||
},
|
||||
targets: []*ProxyTarget{
|
||||
badTarget,
|
||||
goodTarget,
|
||||
},
|
||||
expectedResponse: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "retry count 1 does not retry on handler return false",
|
||||
retryCount: 1,
|
||||
retryFilters: []func(c echo.Context, e error) bool{
|
||||
neverRetryFilter,
|
||||
},
|
||||
targets: []*ProxyTarget{
|
||||
badTarget,
|
||||
goodTarget,
|
||||
},
|
||||
expectedResponse: http.StatusBadGateway,
|
||||
},
|
||||
{
|
||||
name: "retry count 2 returns error when no more retries left",
|
||||
retryCount: 2,
|
||||
retryFilters: []func(c echo.Context, e error) bool{
|
||||
alwaysRetryFilter,
|
||||
alwaysRetryFilter,
|
||||
},
|
||||
targets: []*ProxyTarget{
|
||||
badTarget,
|
||||
badTarget,
|
||||
badTarget,
|
||||
goodTarget, //Should never be reached as only 2 retries
|
||||
},
|
||||
expectedResponse: http.StatusBadGateway,
|
||||
},
|
||||
{
|
||||
name: "retry count 2 returns error when retries left but handler returns false",
|
||||
retryCount: 3,
|
||||
retryFilters: []func(c echo.Context, e error) bool{
|
||||
alwaysRetryFilter,
|
||||
alwaysRetryFilter,
|
||||
neverRetryFilter,
|
||||
},
|
||||
targets: []*ProxyTarget{
|
||||
badTarget,
|
||||
badTarget,
|
||||
badTarget,
|
||||
goodTarget, //Should never be reached as retry handler returns false on 2nd check
|
||||
},
|
||||
expectedResponse: http.StatusBadGateway,
|
||||
},
|
||||
{
|
||||
name: "retry count 3 succeeds",
|
||||
retryCount: 3,
|
||||
retryFilters: []func(c echo.Context, e error) bool{
|
||||
alwaysRetryFilter,
|
||||
alwaysRetryFilter,
|
||||
alwaysRetryFilter,
|
||||
},
|
||||
targets: []*ProxyTarget{
|
||||
badTarget,
|
||||
badTarget,
|
||||
badTarget,
|
||||
goodTarget,
|
||||
},
|
||||
expectedResponse: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "40x responses are not retried",
|
||||
retryCount: 1,
|
||||
targets: []*ProxyTarget{
|
||||
goodTargetWith40X,
|
||||
goodTarget,
|
||||
},
|
||||
expectedResponse: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
||||
retryFilterCall := 0
|
||||
retryFilter := func(c echo.Context, e error) bool {
|
||||
if len(tc.retryFilters) == 0 {
|
||||
assert.FailNow(t, fmt.Sprintf("unexpected calls, %d, to retry handler", retryFilterCall))
|
||||
}
|
||||
|
||||
retryFilterCall++
|
||||
|
||||
nextRetryFilter := tc.retryFilters[0]
|
||||
tc.retryFilters = tc.retryFilters[1:]
|
||||
|
||||
return nextRetryFilter(c, e)
|
||||
}
|
||||
|
||||
e := echo.New()
|
||||
e.Use(ProxyWithConfig(
|
||||
ProxyConfig{
|
||||
Balancer: NewRoundRobinBalancer(tc.targets),
|
||||
RetryCount: tc.retryCount,
|
||||
RetryFilter: retryFilter,
|
||||
},
|
||||
))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, tc.expectedResponse, rec.Code)
|
||||
if len(tc.retryFilters) > 0 {
|
||||
assert.FailNow(t, fmt.Sprintf("expected %d more retry handler calls", len(tc.retryFilters)))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyRetryWithBackendTimeout(t *testing.T) {
|
||||
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.ResponseHeaderTimeout = time.Millisecond * 500
|
||||
|
||||
timeoutBackend := httptest.NewServer(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(1 * time.Second)
|
||||
w.WriteHeader(404)
|
||||
}),
|
||||
)
|
||||
defer timeoutBackend.Close()
|
||||
|
||||
timeoutTargetURL, _ := url.Parse(timeoutBackend.URL)
|
||||
goodBackend := httptest.NewServer(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
}),
|
||||
)
|
||||
defer goodBackend.Close()
|
||||
|
||||
goodTargetURL, _ := url.Parse(goodBackend.URL)
|
||||
e := echo.New()
|
||||
e.Use(ProxyWithConfig(
|
||||
ProxyConfig{
|
||||
Transport: transport,
|
||||
Balancer: NewRoundRobinBalancer([]*ProxyTarget{
|
||||
{
|
||||
Name: "Timeout",
|
||||
URL: timeoutTargetURL,
|
||||
},
|
||||
{
|
||||
Name: "Good",
|
||||
URL: goodTargetURL,
|
||||
},
|
||||
}),
|
||||
RetryCount: 1,
|
||||
},
|
||||
))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 20; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, 200, rec.Code)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
}
|
||||
|
||||
func TestProxyErrorHandler(t *testing.T) {
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
goodURL, _ := url.Parse(server.URL)
|
||||
defer server.Close()
|
||||
goodTarget := &ProxyTarget{
|
||||
Name: "Good",
|
||||
URL: goodURL,
|
||||
}
|
||||
|
||||
badURL, _ := url.Parse("http://127.0.0.1:27121")
|
||||
badTarget := &ProxyTarget{
|
||||
Name: "Bad",
|
||||
URL: badURL,
|
||||
}
|
||||
|
||||
transformedError := errors.New("a new error")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
target *ProxyTarget
|
||||
errorHandler func(c echo.Context, e error) error
|
||||
expectFinalError func(t *testing.T, err error)
|
||||
}{
|
||||
{
|
||||
name: "Error handler not invoked when request success",
|
||||
target: goodTarget,
|
||||
errorHandler: func(c echo.Context, e error) error {
|
||||
assert.FailNow(t, "error handler should not be invoked")
|
||||
return e
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Error handler invoked when request fails",
|
||||
target: badTarget,
|
||||
errorHandler: func(c echo.Context, e error) error {
|
||||
httpErr, ok := e.(*echo.HTTPError)
|
||||
assert.True(t, ok, "expected http error to be passed to handler")
|
||||
assert.Equal(t, http.StatusBadGateway, httpErr.Code, "expected http bad gateway error to be passed to handler")
|
||||
return transformedError
|
||||
},
|
||||
expectFinalError: func(t *testing.T, err error) {
|
||||
assert.Equal(t, transformedError, err, "transformed error not returned from proxy")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
e.Use(ProxyWithConfig(
|
||||
ProxyConfig{
|
||||
Balancer: NewRoundRobinBalancer([]*ProxyTarget{tc.target}),
|
||||
ErrorHandler: tc.errorHandler,
|
||||
},
|
||||
))
|
||||
|
||||
errorHandlerCalled := false
|
||||
e.HTTPErrorHandler = func(err error, c echo.Context) {
|
||||
errorHandlerCalled = true
|
||||
tc.expectFinalError(t, err)
|
||||
e.DefaultHTTPErrorHandler(err, c)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
if !errorHandlerCalled && tc.expectFinalError != nil {
|
||||
t.Fatalf("error handler was not called")
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) {
|
||||
var timeoutStop sync.WaitGroup
|
||||
timeoutStop.Add(1)
|
||||
|
Loading…
Reference in New Issue
Block a user