// Mirror of https://github.com/labstack/echo.git (synced 2025-01-10 00:28:23 +02:00).
// 813 lines, 21 KiB, Go.
// SPDX-License-Identifier: MIT
|
|
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
|
|
|
|
package middleware
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"regexp"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/labstack/echo/v4"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
// Assert expected with url.EscapedPath method to obtain the path.
|
|
func TestProxy(t *testing.T) {
|
|
// Setup
|
|
t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
fmt.Fprint(w, "target 1")
|
|
}))
|
|
defer t1.Close()
|
|
url1, _ := url.Parse(t1.URL)
|
|
t2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
fmt.Fprint(w, "target 2")
|
|
}))
|
|
defer t2.Close()
|
|
url2, _ := url.Parse(t2.URL)
|
|
targets := []*ProxyTarget{
|
|
{
|
|
Name: "target 1",
|
|
URL: url1,
|
|
},
|
|
{
|
|
Name: "target 2",
|
|
URL: url2,
|
|
},
|
|
}
|
|
rb := NewRandomBalancer(nil)
|
|
// must add targets:
|
|
for _, target := range targets {
|
|
assert.True(t, rb.AddTarget(target))
|
|
}
|
|
|
|
// must ignore duplicates:
|
|
for _, target := range targets {
|
|
assert.False(t, rb.AddTarget(target))
|
|
}
|
|
|
|
// Random
|
|
e := echo.New()
|
|
e.Use(Proxy(rb))
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
rec := httptest.NewRecorder()
|
|
e.ServeHTTP(rec, req)
|
|
body := rec.Body.String()
|
|
expected := map[string]bool{
|
|
"target 1": true,
|
|
"target 2": true,
|
|
}
|
|
assert.Condition(t, func() bool {
|
|
return expected[body]
|
|
})
|
|
|
|
for _, target := range targets {
|
|
assert.True(t, rb.RemoveTarget(target.Name))
|
|
}
|
|
|
|
assert.False(t, rb.RemoveTarget("unknown target"))
|
|
|
|
// Round-robin
|
|
rrb := NewRoundRobinBalancer(targets)
|
|
e = echo.New()
|
|
e.Use(Proxy(rrb))
|
|
|
|
rec = httptest.NewRecorder()
|
|
e.ServeHTTP(rec, req)
|
|
body = rec.Body.String()
|
|
assert.Equal(t, "target 1", body)
|
|
|
|
rec = httptest.NewRecorder()
|
|
e.ServeHTTP(rec, req)
|
|
body = rec.Body.String()
|
|
assert.Equal(t, "target 2", body)
|
|
|
|
// ModifyResponse
|
|
e = echo.New()
|
|
e.Use(ProxyWithConfig(ProxyConfig{
|
|
Balancer: rrb,
|
|
ModifyResponse: func(res *http.Response) error {
|
|
res.Body = io.NopCloser(bytes.NewBuffer([]byte("modified")))
|
|
res.Header.Set("X-Modified", "1")
|
|
return nil
|
|
},
|
|
}))
|
|
|
|
rec = httptest.NewRecorder()
|
|
e.ServeHTTP(rec, req)
|
|
assert.Equal(t, "modified", rec.Body.String())
|
|
assert.Equal(t, "1", rec.Header().Get("X-Modified"))
|
|
|
|
// ProxyTarget is set in context
|
|
contextObserver := func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) (err error) {
|
|
next(c)
|
|
assert.Contains(t, targets, c.Get("target"), "target is not set in context")
|
|
return nil
|
|
}
|
|
}
|
|
rrb1 := NewRoundRobinBalancer(targets)
|
|
|
|
e = echo.New()
|
|
e.Use(contextObserver)
|
|
e.Use(Proxy(rrb1))
|
|
rec = httptest.NewRecorder()
|
|
e.ServeHTTP(rec, req)
|
|
}
|
|
|
|
type testProvider struct {
|
|
commonBalancer
|
|
target *ProxyTarget
|
|
err error
|
|
}
|
|
|
|
func (p *testProvider) Next(c echo.Context) *ProxyTarget {
|
|
return &ProxyTarget{}
|
|
}
|
|
|
|
func (p *testProvider) NextTarget(c echo.Context) (*ProxyTarget, error) {
|
|
return p.target, p.err
|
|
}
|
|
|
|
func TestTargetProvider(t *testing.T) {
|
|
t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
fmt.Fprint(w, "target 1")
|
|
}))
|
|
defer t1.Close()
|
|
url1, _ := url.Parse(t1.URL)
|
|
|
|
e := echo.New()
|
|
tp := &testProvider{}
|
|
tp.target = &ProxyTarget{Name: "target 1", URL: url1}
|
|
e.Use(Proxy(tp))
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
e.ServeHTTP(rec, req)
|
|
body := rec.Body.String()
|
|
assert.Equal(t, "target 1", body)
|
|
}
|
|
|
|
func TestFailNextTarget(t *testing.T) {
|
|
url1, err := url.Parse("http://dummy:8080")
|
|
assert.Nil(t, err)
|
|
|
|
e := echo.New()
|
|
tp := &testProvider{}
|
|
tp.target = &ProxyTarget{Name: "target 1", URL: url1}
|
|
tp.err = echo.NewHTTPError(http.StatusInternalServerError, "method could not select target")
|
|
|
|
e.Use(Proxy(tp))
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
e.ServeHTTP(rec, req)
|
|
body := rec.Body.String()
|
|
assert.Equal(t, "{\"message\":\"method could not select target\"}\n", body)
|
|
}
|
|
|
|
func TestProxyRealIPHeader(t *testing.T) {
|
|
// Setup
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
|
defer upstream.Close()
|
|
url, _ := url.Parse(upstream.URL)
|
|
rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}})
|
|
e := echo.New()
|
|
e.Use(Proxy(rrb))
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
remoteAddrIP, _, _ := net.SplitHostPort(req.RemoteAddr)
|
|
realIPHeaderIP := "203.0.113.1"
|
|
extractedRealIP := "203.0.113.10"
|
|
tests := []*struct {
|
|
hasRealIPheader bool
|
|
hasIPExtractor bool
|
|
expectedXRealIP string
|
|
}{
|
|
{false, false, remoteAddrIP},
|
|
{false, true, extractedRealIP},
|
|
{true, false, realIPHeaderIP},
|
|
{true, true, extractedRealIP},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
if tt.hasRealIPheader {
|
|
req.Header.Set(echo.HeaderXRealIP, realIPHeaderIP)
|
|
} else {
|
|
req.Header.Del(echo.HeaderXRealIP)
|
|
}
|
|
if tt.hasIPExtractor {
|
|
e.IPExtractor = func(*http.Request) string {
|
|
return extractedRealIP
|
|
}
|
|
} else {
|
|
e.IPExtractor = nil
|
|
}
|
|
e.ServeHTTP(rec, req)
|
|
assert.Equal(t, tt.expectedXRealIP, req.Header.Get(echo.HeaderXRealIP), "hasRealIPheader: %t / hasIPExtractor: %t", tt.hasRealIPheader, tt.hasIPExtractor)
|
|
}
|
|
}
|
|
|
|
func TestProxyRewrite(t *testing.T) {
|
|
var testCases = []struct {
|
|
whenPath string
|
|
expectProxiedURI string
|
|
expectStatus int
|
|
}{
|
|
{
|
|
whenPath: "/api/users",
|
|
expectProxiedURI: "/users",
|
|
expectStatus: http.StatusOK,
|
|
},
|
|
{
|
|
whenPath: "/js/main.js",
|
|
expectProxiedURI: "/public/javascripts/main.js",
|
|
expectStatus: http.StatusOK,
|
|
},
|
|
{
|
|
whenPath: "/old",
|
|
expectProxiedURI: "/new",
|
|
expectStatus: http.StatusOK,
|
|
},
|
|
{
|
|
whenPath: "/users/jack/orders/1",
|
|
expectProxiedURI: "/user/jack/order/1",
|
|
expectStatus: http.StatusOK,
|
|
},
|
|
{
|
|
whenPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
|
|
expectProxiedURI: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
|
|
expectStatus: http.StatusOK,
|
|
},
|
|
{ // ` ` (space) is encoded by httpClient to `%20` when doing request to Echo. `%20` should not be double escaped when proxying request
|
|
whenPath: "/api/new users",
|
|
expectProxiedURI: "/new%20users",
|
|
expectStatus: http.StatusOK,
|
|
},
|
|
{ // query params should be proxied and not be modified
|
|
whenPath: "/api/users?limit=10",
|
|
expectProxiedURI: "/users?limit=10",
|
|
expectStatus: http.StatusOK,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.whenPath, func(t *testing.T) {
|
|
receivedRequestURI := make(chan string, 1)
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// RequestURI is the unmodified request-target of the Request-Line (RFC 7230, Section 3.1.1) as sent by the client to a server
|
|
// we need unmodified target to see if we are encoding/decoding the url in addition to rewrite/replace logic
|
|
// if original request had `%2F` we should not magically decode it to `/` as it would change what was requested
|
|
receivedRequestURI <- r.RequestURI
|
|
}))
|
|
defer upstream.Close()
|
|
serverURL, _ := url.Parse(upstream.URL)
|
|
rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: serverURL}})
|
|
|
|
// Rewrite
|
|
e := echo.New()
|
|
e.Use(ProxyWithConfig(ProxyConfig{
|
|
Balancer: rrb,
|
|
Rewrite: map[string]string{
|
|
"/old": "/new",
|
|
"/api/*": "/$1",
|
|
"/js/*": "/public/javascripts/$1",
|
|
"/users/*/orders/*": "/user/$1/order/$2",
|
|
},
|
|
}))
|
|
|
|
targetURL, _ := serverURL.Parse(tc.whenPath)
|
|
req := httptest.NewRequest(http.MethodGet, targetURL.String(), nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
e.ServeHTTP(rec, req)
|
|
|
|
assert.Equal(t, tc.expectStatus, rec.Code)
|
|
actualRequestURI := <-receivedRequestURI
|
|
assert.Equal(t, tc.expectProxiedURI, actualRequestURI)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestProxyRewriteRegex(t *testing.T) {
|
|
// Setup
|
|
receivedRequestURI := make(chan string, 1)
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// RequestURI is the unmodified request-target of the Request-Line (RFC 7230, Section 3.1.1) as sent by the client to a server
|
|
// we need unmodified target to see if we are encoding/decoding the url in addition to rewrite/replace logic
|
|
// if original request had `%2F` we should not magically decode it to `/` as it would change what was requested
|
|
receivedRequestURI <- r.RequestURI
|
|
}))
|
|
defer upstream.Close()
|
|
tmpUrL, _ := url.Parse(upstream.URL)
|
|
rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: tmpUrL}})
|
|
|
|
// Rewrite
|
|
e := echo.New()
|
|
e.Use(ProxyWithConfig(ProxyConfig{
|
|
Balancer: rrb,
|
|
Rewrite: map[string]string{
|
|
"^/a/*": "/v1/$1",
|
|
"^/b/*/c/*": "/v2/$2/$1",
|
|
"^/c/*/*": "/v3/$2",
|
|
},
|
|
RegexRewrite: map[*regexp.Regexp]string{
|
|
regexp.MustCompile("^/x/.+?/(.*)"): "/v4/$1",
|
|
regexp.MustCompile("^/y/(.+?)/(.*)"): "/v5/$2/$1",
|
|
},
|
|
}))
|
|
|
|
testCases := []struct {
|
|
requestPath string
|
|
statusCode int
|
|
expectPath string
|
|
}{
|
|
{"/unmatched", http.StatusOK, "/unmatched"},
|
|
{"/a/test", http.StatusOK, "/v1/test"},
|
|
{"/b/foo/c/bar/baz", http.StatusOK, "/v2/bar/baz/foo"},
|
|
{"/c/ignore/test", http.StatusOK, "/v3/test"},
|
|
{"/c/ignore1/test/this", http.StatusOK, "/v3/test/this"},
|
|
{"/x/ignore/test", http.StatusOK, "/v4/test"},
|
|
{"/y/foo/bar", http.StatusOK, "/v5/bar/foo"},
|
|
// NB: fragment is not added by golang httputil.NewSingleHostReverseProxy implementation
|
|
// $2 = `bar?q=1#frag`, $1 = `foo`. replaced uri = `/v5/bar?q=1#frag/foo` but httputil.NewSingleHostReverseProxy does not send `#frag/foo` (currently)
|
|
{"/y/foo/bar?q=1#frag", http.StatusOK, "/v5/bar?q=1"},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.requestPath, func(t *testing.T) {
|
|
targetURL, _ := url.Parse(tc.requestPath)
|
|
req := httptest.NewRequest(http.MethodGet, targetURL.String(), nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
e.ServeHTTP(rec, req)
|
|
|
|
actualRequestURI := <-receivedRequestURI
|
|
assert.Equal(t, tc.expectPath, actualRequestURI)
|
|
assert.Equal(t, tc.statusCode, rec.Code)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestProxyError(t *testing.T) {
|
|
// Setup
|
|
url1, _ := url.Parse("http://127.0.0.1:27121")
|
|
url2, _ := url.Parse("http://127.0.0.1:27122")
|
|
|
|
targets := []*ProxyTarget{
|
|
{
|
|
Name: "target 1",
|
|
URL: url1,
|
|
},
|
|
{
|
|
Name: "target 2",
|
|
URL: url2,
|
|
},
|
|
}
|
|
rb := NewRandomBalancer(nil)
|
|
// must add targets:
|
|
for _, target := range targets {
|
|
assert.True(t, rb.AddTarget(target))
|
|
}
|
|
|
|
// must ignore duplicates:
|
|
for _, target := range targets {
|
|
assert.False(t, rb.AddTarget(target))
|
|
}
|
|
|
|
// Random
|
|
e := echo.New()
|
|
e.Use(Proxy(rb))
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
|
|
// Remote unreachable
|
|
rec := httptest.NewRecorder()
|
|
req.URL.Path = "/api/users"
|
|
e.ServeHTTP(rec, req)
|
|
assert.Equal(t, "/api/users", req.URL.Path)
|
|
assert.Equal(t, http.StatusBadGateway, rec.Code)
|
|
}
|
|
|
|
func TestProxyRetries(t *testing.T) {
|
|
|
|
newServer := func(res int) (*url.URL, *httptest.Server) {
|
|
server := httptest.NewServer(
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(res)
|
|
}),
|
|
)
|
|
targetURL, _ := url.Parse(server.URL)
|
|
return targetURL, server
|
|
}
|
|
|
|
targetURL, server := newServer(http.StatusOK)
|
|
defer server.Close()
|
|
goodTarget := &ProxyTarget{
|
|
Name: "Good",
|
|
URL: targetURL,
|
|
}
|
|
|
|
targetURL, server = newServer(http.StatusBadRequest)
|
|
defer server.Close()
|
|
goodTargetWith40X := &ProxyTarget{
|
|
Name: "Good with 40X",
|
|
URL: targetURL,
|
|
}
|
|
|
|
targetURL, _ = url.Parse("http://127.0.0.1:27121")
|
|
badTarget := &ProxyTarget{
|
|
Name: "Bad",
|
|
URL: targetURL,
|
|
}
|
|
|
|
alwaysRetryFilter := func(c echo.Context, e error) bool { return true }
|
|
neverRetryFilter := func(c echo.Context, e error) bool { return false }
|
|
|
|
testCases := []struct {
|
|
name string
|
|
retryCount int
|
|
retryFilters []func(c echo.Context, e error) bool
|
|
targets []*ProxyTarget
|
|
expectedResponse int
|
|
}{
|
|
{
|
|
name: "retry count 0 does not attempt retry on fail",
|
|
targets: []*ProxyTarget{
|
|
badTarget,
|
|
goodTarget,
|
|
},
|
|
expectedResponse: http.StatusBadGateway,
|
|
},
|
|
{
|
|
name: "retry count 1 does not attempt retry on success",
|
|
retryCount: 1,
|
|
targets: []*ProxyTarget{
|
|
goodTarget,
|
|
},
|
|
expectedResponse: http.StatusOK,
|
|
},
|
|
{
|
|
name: "retry count 1 does retry on handler return true",
|
|
retryCount: 1,
|
|
retryFilters: []func(c echo.Context, e error) bool{
|
|
alwaysRetryFilter,
|
|
},
|
|
targets: []*ProxyTarget{
|
|
badTarget,
|
|
goodTarget,
|
|
},
|
|
expectedResponse: http.StatusOK,
|
|
},
|
|
{
|
|
name: "retry count 1 does not retry on handler return false",
|
|
retryCount: 1,
|
|
retryFilters: []func(c echo.Context, e error) bool{
|
|
neverRetryFilter,
|
|
},
|
|
targets: []*ProxyTarget{
|
|
badTarget,
|
|
goodTarget,
|
|
},
|
|
expectedResponse: http.StatusBadGateway,
|
|
},
|
|
{
|
|
name: "retry count 2 returns error when no more retries left",
|
|
retryCount: 2,
|
|
retryFilters: []func(c echo.Context, e error) bool{
|
|
alwaysRetryFilter,
|
|
alwaysRetryFilter,
|
|
},
|
|
targets: []*ProxyTarget{
|
|
badTarget,
|
|
badTarget,
|
|
badTarget,
|
|
goodTarget, //Should never be reached as only 2 retries
|
|
},
|
|
expectedResponse: http.StatusBadGateway,
|
|
},
|
|
{
|
|
name: "retry count 2 returns error when retries left but handler returns false",
|
|
retryCount: 3,
|
|
retryFilters: []func(c echo.Context, e error) bool{
|
|
alwaysRetryFilter,
|
|
alwaysRetryFilter,
|
|
neverRetryFilter,
|
|
},
|
|
targets: []*ProxyTarget{
|
|
badTarget,
|
|
badTarget,
|
|
badTarget,
|
|
goodTarget, //Should never be reached as retry handler returns false on 2nd check
|
|
},
|
|
expectedResponse: http.StatusBadGateway,
|
|
},
|
|
{
|
|
name: "retry count 3 succeeds",
|
|
retryCount: 3,
|
|
retryFilters: []func(c echo.Context, e error) bool{
|
|
alwaysRetryFilter,
|
|
alwaysRetryFilter,
|
|
alwaysRetryFilter,
|
|
},
|
|
targets: []*ProxyTarget{
|
|
badTarget,
|
|
badTarget,
|
|
badTarget,
|
|
goodTarget,
|
|
},
|
|
expectedResponse: http.StatusOK,
|
|
},
|
|
{
|
|
name: "40x responses are not retried",
|
|
retryCount: 1,
|
|
targets: []*ProxyTarget{
|
|
goodTargetWith40X,
|
|
goodTarget,
|
|
},
|
|
expectedResponse: http.StatusBadRequest,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
|
|
retryFilterCall := 0
|
|
retryFilter := func(c echo.Context, e error) bool {
|
|
if len(tc.retryFilters) == 0 {
|
|
assert.FailNow(t, fmt.Sprintf("unexpected calls, %d, to retry handler", retryFilterCall))
|
|
}
|
|
|
|
retryFilterCall++
|
|
|
|
nextRetryFilter := tc.retryFilters[0]
|
|
tc.retryFilters = tc.retryFilters[1:]
|
|
|
|
return nextRetryFilter(c, e)
|
|
}
|
|
|
|
e := echo.New()
|
|
e.Use(ProxyWithConfig(
|
|
ProxyConfig{
|
|
Balancer: NewRoundRobinBalancer(tc.targets),
|
|
RetryCount: tc.retryCount,
|
|
RetryFilter: retryFilter,
|
|
},
|
|
))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
e.ServeHTTP(rec, req)
|
|
|
|
assert.Equal(t, tc.expectedResponse, rec.Code)
|
|
if len(tc.retryFilters) > 0 {
|
|
assert.FailNow(t, fmt.Sprintf("expected %d more retry handler calls", len(tc.retryFilters)))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestProxyRetryWithBackendTimeout(t *testing.T) {
|
|
|
|
transport := http.DefaultTransport.(*http.Transport).Clone()
|
|
transport.ResponseHeaderTimeout = time.Millisecond * 500
|
|
|
|
timeoutBackend := httptest.NewServer(
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
time.Sleep(1 * time.Second)
|
|
w.WriteHeader(404)
|
|
}),
|
|
)
|
|
defer timeoutBackend.Close()
|
|
|
|
timeoutTargetURL, _ := url.Parse(timeoutBackend.URL)
|
|
goodBackend := httptest.NewServer(
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(200)
|
|
}),
|
|
)
|
|
defer goodBackend.Close()
|
|
|
|
goodTargetURL, _ := url.Parse(goodBackend.URL)
|
|
e := echo.New()
|
|
e.Use(ProxyWithConfig(
|
|
ProxyConfig{
|
|
Transport: transport,
|
|
Balancer: NewRoundRobinBalancer([]*ProxyTarget{
|
|
{
|
|
Name: "Timeout",
|
|
URL: timeoutTargetURL,
|
|
},
|
|
{
|
|
Name: "Good",
|
|
URL: goodTargetURL,
|
|
},
|
|
}),
|
|
RetryCount: 1,
|
|
},
|
|
))
|
|
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < 20; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
rec := httptest.NewRecorder()
|
|
e.ServeHTTP(rec, req)
|
|
assert.Equal(t, 200, rec.Code)
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
}
|
|
|
|
func TestProxyErrorHandler(t *testing.T) {
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
goodURL, _ := url.Parse(server.URL)
|
|
defer server.Close()
|
|
goodTarget := &ProxyTarget{
|
|
Name: "Good",
|
|
URL: goodURL,
|
|
}
|
|
|
|
badURL, _ := url.Parse("http://127.0.0.1:27121")
|
|
badTarget := &ProxyTarget{
|
|
Name: "Bad",
|
|
URL: badURL,
|
|
}
|
|
|
|
transformedError := errors.New("a new error")
|
|
|
|
testCases := []struct {
|
|
name string
|
|
target *ProxyTarget
|
|
errorHandler func(c echo.Context, e error) error
|
|
expectFinalError func(t *testing.T, err error)
|
|
}{
|
|
{
|
|
name: "Error handler not invoked when request success",
|
|
target: goodTarget,
|
|
errorHandler: func(c echo.Context, e error) error {
|
|
assert.FailNow(t, "error handler should not be invoked")
|
|
return e
|
|
},
|
|
},
|
|
{
|
|
name: "Error handler invoked when request fails",
|
|
target: badTarget,
|
|
errorHandler: func(c echo.Context, e error) error {
|
|
httpErr, ok := e.(*echo.HTTPError)
|
|
assert.True(t, ok, "expected http error to be passed to handler")
|
|
assert.Equal(t, http.StatusBadGateway, httpErr.Code, "expected http bad gateway error to be passed to handler")
|
|
return transformedError
|
|
},
|
|
expectFinalError: func(t *testing.T, err error) {
|
|
assert.Equal(t, transformedError, err, "transformed error not returned from proxy")
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
e := echo.New()
|
|
e.Use(ProxyWithConfig(
|
|
ProxyConfig{
|
|
Balancer: NewRoundRobinBalancer([]*ProxyTarget{tc.target}),
|
|
ErrorHandler: tc.errorHandler,
|
|
},
|
|
))
|
|
|
|
errorHandlerCalled := false
|
|
e.HTTPErrorHandler = func(err error, c echo.Context) {
|
|
errorHandlerCalled = true
|
|
tc.expectFinalError(t, err)
|
|
e.DefaultHTTPErrorHandler(err, c)
|
|
}
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
e.ServeHTTP(rec, req)
|
|
|
|
if !errorHandlerCalled && tc.expectFinalError != nil {
|
|
t.Fatalf("error handler was not called")
|
|
}
|
|
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) {
|
|
var timeoutStop sync.WaitGroup
|
|
timeoutStop.Add(1)
|
|
HTTPTarget := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
timeoutStop.Wait() // wait until we have canceled the request
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer HTTPTarget.Close()
|
|
targetURL, _ := url.Parse(HTTPTarget.URL)
|
|
target := &ProxyTarget{
|
|
Name: "target",
|
|
URL: targetURL,
|
|
}
|
|
rb := NewRandomBalancer(nil)
|
|
assert.True(t, rb.AddTarget(target))
|
|
e := echo.New()
|
|
e.Use(Proxy(rb))
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
ctx, cancel := context.WithCancel(req.Context())
|
|
req = req.WithContext(ctx)
|
|
go func() {
|
|
time.Sleep(10 * time.Millisecond)
|
|
cancel()
|
|
}()
|
|
e.ServeHTTP(rec, req)
|
|
timeoutStop.Done()
|
|
assert.Equal(t, 499, rec.Code)
|
|
}
|
|
|
|
// Assert balancer with empty targets does return `nil` on `Next()`
|
|
func TestProxyBalancerWithNoTargets(t *testing.T) {
|
|
rb := NewRandomBalancer(nil)
|
|
assert.Nil(t, rb.Next(nil))
|
|
|
|
rrb := NewRoundRobinBalancer([]*ProxyTarget{})
|
|
assert.Nil(t, rrb.Next(nil))
|
|
}
|
|
|
|
// testContextKey is the key type for values that customBalancer.Next stores in
// the request context and that TestModifyResponseUseContext reads back.
type testContextKey string
|
|
|
|
type customBalancer struct {
|
|
target *ProxyTarget
|
|
}
|
|
|
|
func (b *customBalancer) AddTarget(target *ProxyTarget) bool {
|
|
return false
|
|
}
|
|
|
|
func (b *customBalancer) RemoveTarget(name string) bool {
|
|
return false
|
|
}
|
|
|
|
func (b *customBalancer) Next(c echo.Context) *ProxyTarget {
|
|
ctx := context.WithValue(c.Request().Context(), testContextKey("FROM_BALANCER"), "CUSTOM_BALANCER")
|
|
c.SetRequest(c.Request().WithContext(ctx))
|
|
return b.target
|
|
}
|
|
|
|
func TestModifyResponseUseContext(t *testing.T) {
|
|
server := httptest.NewServer(
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("OK"))
|
|
}),
|
|
)
|
|
defer server.Close()
|
|
|
|
targetURL, _ := url.Parse(server.URL)
|
|
e := echo.New()
|
|
e.Use(ProxyWithConfig(
|
|
ProxyConfig{
|
|
Balancer: &customBalancer{
|
|
target: &ProxyTarget{
|
|
Name: "tst",
|
|
URL: targetURL,
|
|
},
|
|
},
|
|
RetryCount: 1,
|
|
ModifyResponse: func(res *http.Response) error {
|
|
val := res.Request.Context().Value(testContextKey("FROM_BALANCER"))
|
|
if valStr, ok := val.(string); ok {
|
|
res.Header.Set("FROM_BALANCER", valStr)
|
|
}
|
|
return nil
|
|
},
|
|
},
|
|
))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
e.ServeHTTP(rec, req)
|
|
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
assert.Equal(t, "OK", rec.Body.String())
|
|
assert.Equal(t, "CUSTOM_BALANCER", rec.Header().Get("FROM_BALANCER"))
|
|
}
|