1
0
mirror of https://github.com/umputun/reproxy.git synced 2025-02-16 18:34:30 +02:00

add middleware to optionally allow requests from giving ips/ranges

add new remote param to docker and file providers

lint: http nil body

add support of remote ips to consul provider

local implementation of onlyfrom middleware

lint: missing comment

make proxy tests more readable

preffer public IP if any forwwarded
This commit is contained in:
Umputun 2023-11-25 13:54:14 -06:00
parent 36a5443378
commit a896f08eec
16 changed files with 531 additions and 142 deletions

View File

@ -37,6 +37,7 @@ type URLMapper struct {
PingURL string
MatchType MatchType
RedirectType RedirectType
OnlyFromIPs []string
AssetsLocation string // local FS root location
AssetsWebRoot string // web root location
@ -484,16 +485,6 @@ func (s *Service) mergeEvents(ctx context.Context, chs ...<-chan ProviderID) <-c
return out
}
// Contains checks if the input string (e) in the given slice
func Contains(e string, s []string) bool {
for _, a := range s {
if a == e {
return true
}
}
return false
}
// IsAlive indicates whether mapper destination is alive
func (m URLMapper) IsAlive() bool {
return !m.dead
@ -515,3 +506,24 @@ func (m URLMapper) ping() (string, error) {
return "", err
}
// Contains checks if the input string (e) in the given slice
func Contains(e string, s []string) bool {
for _, a := range s {
if a == e {
return true
}
}
return false
}
// ParseOnlyFrom parses comma separated list of IPs
func ParseOnlyFrom(s string) (res []string) {
if s == "" {
return []string{}
}
for _, v := range strings.Split(s, ",") {
res = append(res, strings.TrimSpace(v))
}
return res
}

View File

@ -39,7 +39,7 @@ func TestService_Run(t *testing.T) {
ListFunc: func() ([]URLMapper, error) {
return []URLMapper{
{Server: "localhost", SrcMatch: *regexp.MustCompile("/api/svc3/xyz"),
Dst: "http://127.0.0.3:8080/blah3/xyz", ProviderID: PIDocker},
Dst: "http://127.0.0.3:8080/blah3/xyz", ProviderID: PIDocker, OnlyFromIPs: []string{"127.0.0.1"}},
}, nil
},
}
@ -66,6 +66,7 @@ func TestService_Run(t *testing.T) {
assert.Equal(t, "localhost", mappers[0].Server)
assert.Equal(t, "/api/svc3/xyz", mappers[0].SrcMatch.String())
assert.Equal(t, "http://127.0.0.3:8080/blah3/xyz", mappers[0].Dst)
assert.Equal(t, []string{"127.0.0.1"}, mappers[0].OnlyFromIPs)
assert.Equal(t, 1, len(p1.EventsCalls()))
assert.Equal(t, 1, len(p2.EventsCalls()))
@ -104,7 +105,8 @@ func TestService_Match(t *testing.T) {
},
ListFunc: func() ([]URLMapper, error) {
return []URLMapper{
{SrcMatch: *regexp.MustCompile("/api/svc3/xyz"), Dst: "http://127.0.0.3:8080/blah3/xyz", ProviderID: PIDocker},
{SrcMatch: *regexp.MustCompile("/api/svc3/xyz"), Dst: "http://127.0.0.3:8080/blah3/xyz",
OnlyFromIPs: []string{"127.0.0.1", "192.168.1.0/24"}, ProviderID: PIDocker},
{SrcMatch: *regexp.MustCompile("/web"), Dst: "/var/web", ProviderID: PIDocker, MatchType: MTStatic,
AssetsWebRoot: "/web", AssetsLocation: "/var/web"},
{SrcMatch: *regexp.MustCompile("/www/"), Dst: "/var/web", ProviderID: PIDocker, MatchType: MTStatic,
@ -131,9 +133,11 @@ func TestService_Match(t *testing.T) {
res Matches
}{
{"example.com", "/api/svc3/xyz/something", Matches{MTProxy, []MatchedRoute{
{Destination: "http://127.0.0.3:8080/blah3/xyz/something", Alive: true}}}},
{Destination: "http://127.0.0.3:8080/blah3/xyz/something", Alive: true,
Mapper: URLMapper{OnlyFromIPs: []string{"127.0.0.1", "192.168.1.0/24"}}}}}},
{"example.com", "/api/svc3/xyz", Matches{MTProxy, []MatchedRoute{{
Destination: "http://127.0.0.3:8080/blah3/xyz", Alive: true}}}},
Destination: "http://127.0.0.3:8080/blah3/xyz", Alive: true,
Mapper: URLMapper{OnlyFromIPs: []string{"127.0.0.1", "192.168.1.0/24"}}}}}},
{"abc.example.com", "/api/svc1/1234", Matches{MTProxy, []MatchedRoute{
{Destination: "http://127.0.0.1:8080/blah1/1234", Alive: true}}}},
{"zzz.example.com", "/aaa/api/svc1/1234", Matches{MTProxy, nil}},
@ -167,6 +171,7 @@ func TestService_Match(t *testing.T) {
for i := 0; i < len(res.Routes); i++ {
assert.Equal(t, tt.res.Routes[i].Alive, res.Routes[i].Alive)
assert.Equal(t, tt.res.Routes[i].Destination, res.Routes[i].Destination)
assert.Equal(t, tt.res.Routes[i].Mapper.OnlyFromIPs, res.Routes[i].Mapper.OnlyFromIPs)
}
assert.Equal(t, tt.res.MatchType, res.MatchType)
})
@ -608,3 +613,39 @@ func TestCheckHealth(t *testing.T) {
assert.NoError(t, res[ts.URL])
assert.NoError(t, res[ts2.URL])
}
func TestParseOnlyFrom(t *testing.T) {
tbl := []struct {
name string
input string
expected []string
}{
{
name: "empty string",
input: "",
expected: []string{},
},
{
name: "single IP",
input: "192.168.1.1",
expected: []string{"192.168.1.1"},
},
{
name: "multiple IPs",
input: "192.168.1.1, 192.168.1.2, 192.168.1.3, 10.0.0.0/16",
expected: []string{"192.168.1.1", "192.168.1.2", "192.168.1.3", "10.0.0.0/16"},
},
{
name: "multiple IPs with extra spaces",
input: " 192.168.1.1 , 192.168.1.2 , 192.168.1.3 ",
expected: []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"},
},
}
for _, tt := range tbl {
t.Run(tt.name, func(t *testing.T) {
result := ParseOnlyFrom(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}

View File

@ -3,12 +3,13 @@ package consulcatalog
import (
"context"
"fmt"
"github.com/umputun/reproxy/app/discovery"
"log"
"regexp"
"sort"
"strings"
"time"
"github.com/umputun/reproxy/app/discovery"
)
//go:generate moq -out consul_client_mock.go -skip-ensure -fmt goimports . ConsulClient
@ -139,6 +140,7 @@ func (cc *ConsulCatalog) List() ([]discovery.URLMapper, error) {
destURL := fmt.Sprintf("http://%s:%d/$1", c.ServiceAddress, c.ServicePort)
pingURL := fmt.Sprintf("http://%s:%d/ping", c.ServiceAddress, c.ServicePort)
server := "*"
onlyFrom := []string{}
if v, ok := c.Labels["reproxy.enabled"]; ok && (v == "true" || v == "yes" || v == "1") {
enabled = true
@ -159,6 +161,10 @@ func (cc *ConsulCatalog) List() ([]discovery.URLMapper, error) {
server = v
}
if v, ok := c.Labels["reproxy.remote"]; ok {
onlyFrom = discovery.ParseOnlyFrom(v)
}
if v, ok := c.Labels["reproxy.ping"]; ok {
enabled = true
pingURL = fmt.Sprintf("http://%s:%d%s", c.ServiceAddress, c.ServicePort, v)
@ -177,7 +183,7 @@ func (cc *ConsulCatalog) List() ([]discovery.URLMapper, error) {
// server label may have multiple, comma separated servers
for _, srv := range strings.Split(server, ",") {
res = append(res, discovery.URLMapper{Server: strings.TrimSpace(srv), SrcMatch: *srcRegex, Dst: destURL,
PingURL: pingURL, ProviderID: discovery.PIConsulCatalog})
OnlyFromIPs: onlyFrom, PingURL: pingURL, ProviderID: discovery.PIConsulCatalog})
}
}

View File

@ -3,12 +3,14 @@ package consulcatalog
import (
"context"
"fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/umputun/reproxy/app/discovery"
"sort"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/umputun/reproxy/app/discovery"
)
func TestNew(t *testing.T) {
@ -62,7 +64,8 @@ func TestConsulCatalog_List(t *testing.T) {
ServiceAddress: "addr3",
ServicePort: 3000,
Labels: map[string]string{"reproxy.route": "^/api/123/(.*)", "reproxy.dest": "/blah/$1",
"reproxy.server": "example.com,domain.com", "reproxy.ping": "/ping", "reproxy.enabled": "yes"},
"reproxy.server": "example.com,domain.com", "reproxy.ping": "/ping",
"reproxy.enabled": "yes", "reproxy.remote": "127.0.0.1, 192.168.1.0/24"},
},
{
ServiceID: "id4",
@ -91,21 +94,25 @@ func TestConsulCatalog_List(t *testing.T) {
assert.Equal(t, "http://addr3:3000/blah/$1", res[0].Dst)
assert.Equal(t, "example.com", res[0].Server)
assert.Equal(t, "http://addr3:3000/ping", res[0].PingURL)
assert.Equal(t, []string{"127.0.0.1", "192.168.1.0/24"}, res[0].OnlyFromIPs)
assert.Equal(t, "^/api/123/(.*)", res[1].SrcMatch.String())
assert.Equal(t, "http://addr3:3000/blah/$1", res[1].Dst)
assert.Equal(t, "domain.com", res[1].Server)
assert.Equal(t, "http://addr3:3000/ping", res[1].PingURL)
assert.Equal(t, []string{"127.0.0.1", "192.168.1.0/24"}, res[1].OnlyFromIPs)
assert.Equal(t, "^/(.*)", res[2].SrcMatch.String())
assert.Equal(t, "http://addr44:4000/$1", res[2].Dst)
assert.Equal(t, "http://addr44:4000/ping", res[2].PingURL)
assert.Equal(t, "*", res[2].Server)
assert.Equal(t, []string{}, res[2].OnlyFromIPs)
assert.Equal(t, "^/(.*)", res[3].SrcMatch.String())
assert.Equal(t, "http://addr2:2000/$1", res[3].Dst)
assert.Equal(t, "http://addr2:2000/ping", res[3].PingURL)
assert.Equal(t, "*", res[3].Server)
assert.Equal(t, []string{}, res[3].OnlyFromIPs)
}
func TestConsulCatalog_serviceListWasChanged(t *testing.T) {

View File

@ -103,6 +103,7 @@ func (d *Docker) parseContainerInfo(c containerInfo) (res []discovery.URLMapper)
// defaults
destURL, pingURL, server := fmt.Sprintf("http://%s:%d/$1", c.IP, port), fmt.Sprintf("http://%s:%d/ping", c.IP, port), "*"
assetsWebRoot, assetsLocation, assetsSPA := "", "", false
onlyFrom := []string{}
if d.AutoAPI && n == 0 {
enabled = true
@ -133,6 +134,10 @@ func (d *Docker) parseContainerInfo(c containerInfo) (res []discovery.URLMapper)
server = v
}
if v, ok := d.labelN(c.Labels, n, "remote"); ok {
onlyFrom = discovery.ParseOnlyFrom(v)
}
if v, ok := d.labelN(c.Labels, n, "ping"); ok {
enabled = true
if strings.HasPrefix(v, "http://") || strings.HasPrefix(v, "https://") {
@ -171,7 +176,7 @@ func (d *Docker) parseContainerInfo(c containerInfo) (res []discovery.URLMapper)
// docker server label may have multiple, comma separated servers
for _, srv := range strings.Split(server, ",") {
mp := discovery.URLMapper{Server: strings.TrimSpace(srv), SrcMatch: *srcRegex, Dst: destURL,
PingURL: pingURL, ProviderID: discovery.PIDocker, MatchType: discovery.MTProxy}
PingURL: pingURL, OnlyFromIPs: onlyFrom, ProviderID: discovery.PIDocker, MatchType: discovery.MTProxy}
// for assets we add the second proxy mapping only if explicitly requested
if assetsWebRoot != "" && explicit {

View File

@ -30,7 +30,7 @@ func TestDocker_List(t *testing.T) {
{
Name: "c1", State: "running", IP: "127.0.0.2", Ports: []int{12345},
Labels: map[string]string{"reproxy.route": "^/api/123/(.*)", "reproxy.dest": "/blah/$1",
"reproxy.server": "example.com", "reproxy.ping": "/ping"},
"reproxy.server": "example.com", "reproxy.ping": "/ping", "reproxy.remote": "192.168.1.0/24, 127.0.0.1"},
},
{
Name: "c1", State: "running", IP: "127.0.0.21", Ports: []int{12345},
@ -64,21 +64,25 @@ func TestDocker_List(t *testing.T) {
assert.Equal(t, "http://127.0.0.2:12345/blah/$1", res[0].Dst)
assert.Equal(t, "example.com", res[0].Server)
assert.Equal(t, "http://127.0.0.2:12345/ping", res[0].PingURL)
assert.Equal(t, []string{"192.168.1.0/24", "127.0.0.1"}, res[0].OnlyFromIPs)
assert.Equal(t, "^/api/90/(.*)", res[1].SrcMatch.String())
assert.Equal(t, "http://example.com/blah/$1", res[1].Dst)
assert.Equal(t, "https://example.com//ping", res[1].PingURL)
assert.Equal(t, "example.com", res[1].Server)
assert.Equal(t, []string{}, res[1].OnlyFromIPs)
assert.Equal(t, "^/c2/(.*)", res[2].SrcMatch.String())
assert.Equal(t, "http://127.0.0.3:12346/$1", res[2].Dst)
assert.Equal(t, "http://127.0.0.3:12346/ping", res[2].PingURL)
assert.Equal(t, "*", res[2].Server)
assert.Equal(t, []string{}, res[2].OnlyFromIPs)
assert.Equal(t, "^/a/(.*)", res[3].SrcMatch.String())
assert.Equal(t, "http://127.0.0.2:12348/a/$1", res[3].Dst)
assert.Equal(t, "http://127.0.0.2:12348/ping", res[3].PingURL)
assert.Equal(t, "example.com", res[3].Server)
assert.Equal(t, []string{}, res[3].OnlyFromIPs)
}
func TestDocker_ListMulti(t *testing.T) {

View File

@ -84,6 +84,7 @@ func (d *File) List() (res []discovery.URLMapper, err error) {
Ping string `yaml:"ping"`
AssetsEnabled bool `yaml:"assets"`
AssetsSPA bool `yaml:"spa"`
OnlyFrom string `yaml:"remote"`
}
fh, err := os.Open(d.FileName)
if err != nil {
@ -106,12 +107,13 @@ func (d *File) List() (res []discovery.URLMapper, err error) {
srv = "*"
}
mapper := discovery.URLMapper{
Server: srv,
SrcMatch: *rx,
Dst: f.Dest,
PingURL: f.Ping,
ProviderID: discovery.PIFile,
MatchType: discovery.MTProxy,
Server: srv,
SrcMatch: *rx,
Dst: f.Dest,
PingURL: f.Ping,
ProviderID: discovery.PIFile,
MatchType: discovery.MTProxy,
OnlyFromIPs: discovery.ParseOnlyFrom(f.OnlyFrom),
}
if f.AssetsEnabled || f.AssetsSPA {
mapper.MatchType = discovery.MTStatic

View File

@ -113,18 +113,21 @@ func TestFile_List(t *testing.T) {
assert.Equal(t, "", res[0].PingURL)
assert.Equal(t, "srv.example.com", res[0].Server)
assert.Equal(t, discovery.MTProxy, res[0].MatchType)
assert.Equal(t, []string{}, res[0].OnlyFromIPs)
assert.Equal(t, "^/api/svc1/(.*)", res[1].SrcMatch.String())
assert.Equal(t, "http://127.0.0.1:8080/blah1/$1", res[1].Dst)
assert.Equal(t, "", res[1].PingURL)
assert.Equal(t, "*", res[1].Server)
assert.Equal(t, discovery.MTProxy, res[1].MatchType)
assert.Equal(t, []string{}, res[0].OnlyFromIPs)
assert.Equal(t, "/api/svc3/xyz", res[2].SrcMatch.String())
assert.Equal(t, "http://127.0.0.3:8080/blah3/xyz", res[2].Dst)
assert.Equal(t, "http://127.0.0.3:8080/ping", res[2].PingURL)
assert.Equal(t, "*", res[2].Server)
assert.Equal(t, discovery.MTProxy, res[2].MatchType)
assert.Equal(t, []string{}, res[0].OnlyFromIPs)
assert.Equal(t, "/web/", res[3].SrcMatch.String())
assert.Equal(t, "/var/web", res[3].Dst)
@ -132,6 +135,7 @@ func TestFile_List(t *testing.T) {
assert.Equal(t, "*", res[3].Server)
assert.Equal(t, discovery.MTStatic, res[3].MatchType)
assert.Equal(t, false, res[3].AssetsSPA)
assert.Equal(t, []string{"192.168.1.0/24", "124.0.0.1"}, res[3].OnlyFromIPs)
assert.Equal(t, "/web2/", res[4].SrcMatch.String())
assert.Equal(t, "/var/web2", res[4].Dst)
@ -139,4 +143,5 @@ func TestFile_List(t *testing.T) {
assert.Equal(t, "*", res[4].Server)
assert.Equal(t, discovery.MTStatic, res[4].MatchType)
assert.Equal(t, true, res[4].AssetsSPA)
assert.Equal(t, []string{}, res[0].OnlyFromIPs)
}

View File

@ -9,7 +9,7 @@ import (
"github.com/umputun/reproxy/app/discovery"
)
// Static provider, rules are server,from,to
// Static provider, rules are server,source_url,destination[,ping]
type Static struct {
Rules []string // each rule is 4 elements comma separated - server,source_url,destination,ping
}

View File

@ -1,7 +1,7 @@
default:
- {route: "^/api/svc1/(.*)", dest: "http://127.0.0.1:8080/blah1/$1"}
- {route: "/api/svc3/xyz", dest: "http://127.0.0.3:8080/blah3/xyz", "ping": "http://127.0.0.3:8080/ping"}
- {route: "/web/", dest: "/var/web", "assets": yes}
- {route: "/web/", dest: "/var/web", "assets": yes, "remote": "192.168.1.0/24, 124.0.0.1"}
- {route: "/web2/", dest: "/var/web2", "spa": yes}
srv.example.com:
- {route: "^/api/svc2/(.*)", dest: "http://127.0.0.2:8080/blah2/$1/abc"}

View File

@ -29,14 +29,14 @@ import (
)
var opts struct {
Listen string `short:"l" long:"listen" env:"LISTEN" description:"listen on host:port (default: 0.0.0.0:8080/8443 under docker, 127.0.0.1:80/443 without)"`
MaxSize string `short:"m" long:"max" env:"MAX_SIZE" default:"64K" description:"max request size"`
GzipEnabled bool `short:"g" long:"gzip" env:"GZIP" description:"enable gz compression"`
ProxyHeaders []string `short:"x" long:"header" description:"outgoing proxy headers to add"` // env HEADER split in code to allow , inside ""
DropHeaders []string `long:"drop-header" env:"DROP_HEADERS" description:"incoming headers to drop" env-delim:","`
AuthBasicHtpasswd string `long:"basic-htpasswd" env:"BASIC_HTPASSWD" description:"htpasswd file for basic auth"`
LBType string `long:"lb-type" env:"LB_TYPE" description:"load balancer type" choice:"random" choice:"failover" default:"random"` // nolint
Listen string `short:"l" long:"listen" env:"LISTEN" description:"listen on host:port (default: 0.0.0.0:8080/8443 under docker, 127.0.0.1:80/443 without)"`
MaxSize string `short:"m" long:"max" env:"MAX_SIZE" default:"64K" description:"max request size"`
GzipEnabled bool `short:"g" long:"gzip" env:"GZIP" description:"enable gz compression"`
ProxyHeaders []string `short:"x" long:"header" description:"outgoing proxy headers to add"` // env HEADER split in code to allow , inside ""
DropHeaders []string `long:"drop-header" env:"DROP_HEADERS" description:"incoming headers to drop" env-delim:","`
AuthBasicHtpasswd string `long:"basic-htpasswd" env:"BASIC_HTPASSWD" description:"htpasswd file for basic auth"`
RemoteLookupHeaders bool `long:"remote-lookup-headers" env:"REMOTE_LOOKUP_HEADERS" description:"enable remote lookup headers"`
LBType string `long:"lb-type" env:"LB_TYPE" description:"load balancer type" choice:"random" choice:"failover" default:"random"` // nolint
SSL struct {
Type string `long:"type" env:"TYPE" description:"ssl (auto) support" choice:"none" choice:"static" choice:"auto" default:"none"` // nolint
@ -273,10 +273,11 @@ func run() error {
ThrottleUser: opts.Throttle.User,
BasicAuthEnabled: len(basicAuthAllowed) > 0,
BasicAuthAllowed: basicAuthAllowed,
OnlyFrom: makeOnlyFromMiddleware(),
}
err = px.Run(ctx)
if err != nil && err == http.ErrServerClosed {
if err != nil && errors.Is(err, http.ErrServerClosed) {
log.Printf("[WARN] proxy server closed, %v", err) // nolint gocritic
return nil
}
@ -424,6 +425,13 @@ func makeLBSelector() func(len int) int {
}
}
func makeOnlyFromMiddleware() *proxy.OnlyFrom {
if opts.RemoteLookupHeaders {
return proxy.NewOnlyFrom(proxy.OFRealIP, proxy.OFForwarded, proxy.OFRemoteAddr)
}
return proxy.NewOnlyFrom(proxy.OFRemoteAddr)
}
func makeErrorReporter() (proxy.Reporter, error) {
result := &proxy.ErrorReporter{
Nice: opts.ErrorReport.Enabled,

View File

@ -244,5 +244,4 @@ func TestHttp_basicAuthHandler(t *testing.T) {
require.Equal(t, http.StatusOK, resp.StatusCode)
})
}
}

149
app/proxy/only_from.go Normal file
View File

@ -0,0 +1,149 @@
package proxy
import (
"bytes"
"net"
"net/http"
"strings"
"github.com/umputun/reproxy/app/discovery"
)
// OnlyFrom implements middleware to allow access for a limited list of source IPs.
type OnlyFrom struct {
lookups []OFLookup
}
// OFLookup defines lookup method for source IP.
type OFLookup string
// enum of possible lookup methods
const (
OFRemoteAddr OFLookup = "remote-addr"
OFRealIP OFLookup = "real-ip"
OFForwarded OFLookup = "forwarded"
)
// NewOnlyFrom creates OnlyFrom middleware with given lookup methods.
func NewOnlyFrom(lookups ...OFLookup) *OnlyFrom {
return &OnlyFrom{lookups: lookups}
}
// Handler implements middleware interface.
func (o *OnlyFrom) Handler(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
var allowedIPs []string
reqCtx := r.Context()
if reqCtx.Value(ctxMatch) != nil { // route match detected by matchHandler
match := reqCtx.Value(ctxMatch).(discovery.MatchedRoute)
allowedIPs = match.Mapper.OnlyFromIPs
}
if len(allowedIPs) == 0 {
// no restrictions if no ips defined
next.ServeHTTP(w, r)
return
}
realIP := o.realIP(o.lookups, r)
if realIP != "" && o.matchRemoteIP(realIP, allowedIPs) {
next.ServeHTTP(w, r)
return
}
w.WriteHeader(http.StatusForbidden)
}
return http.HandlerFunc(fn)
}
func (o *OnlyFrom) realIP(ipLookups []OFLookup, r *http.Request) string {
realIP := r.Header.Get("X-Real-IP")
forwardedFor := r.Header.Get("X-Forwarded-For")
for _, lookup := range ipLookups {
if lookup == OFRemoteAddr {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr // can't parse, return as is
}
return ip
}
if lookup == OFForwarded && forwardedFor != "" {
// X-Forwarded-For is potentially a list of addresses separated with ","
// The left-most being the original client, and each successive proxy that passed the request
// adding the IP address where it received the request from.
// In case if the original IP is a private behind a proxy, we need to get the first public IP from the list
return preferPublicIP(strings.Split(forwardedFor, ","))
}
if lookup == OFRealIP && realIP != "" {
return realIP
}
}
return "" // we can't get real ip
}
// matchRemoteIP returns true if request's ip matches any of ips in the list of allowedIPs.
// allowedIPs can be defined as IP (like 192.168.1.12) or CIDR (192.168.0.0/16)
func (o *OnlyFrom) matchRemoteIP(remoteIP string, allowedIPs []string) bool {
for _, allowedIP := range allowedIPs {
// check for ip prefix or CIDR
if _, cidrnet, err := net.ParseCIDR(allowedIP); err == nil {
if cidrnet.Contains(net.ParseIP(remoteIP)) {
return true
}
}
// check for ip match
if remoteIP == allowedIP {
return true
}
}
return false
}
// preferPublicIP returns first public IP from the list of IPs
// if no public IP found, returns first IP from the list
func preferPublicIP(ips []string) string {
for _, ip := range ips {
ip = strings.TrimSpace(ip)
if net.ParseIP(ip).IsGlobalUnicast() && !isPrivateSubnet(net.ParseIP(ip)) {
return ip
}
}
return strings.TrimSpace(ips[0])
}
type ipRange struct {
start net.IP
end net.IP
}
var privateRanges = []ipRange{
{start: net.ParseIP("10.0.0.0"), end: net.ParseIP("10.255.255.255")},
{start: net.ParseIP("100.64.0.0"), end: net.ParseIP("100.127.255.255")},
{start: net.ParseIP("172.16.0.0"), end: net.ParseIP("172.31.255.255")},
{start: net.ParseIP("192.0.0.0"), end: net.ParseIP("192.0.0.255")},
{start: net.ParseIP("192.168.0.0"), end: net.ParseIP("192.168.255.255")},
{start: net.ParseIP("198.18.0.0"), end: net.ParseIP("198.19.255.255")},
{start: net.ParseIP("::1"), end: net.ParseIP("::1")},
{start: net.ParseIP("fc00::"), end: net.ParseIP("fdff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")},
{start: net.ParseIP("fe80::"), end: net.ParseIP("febf:ffff:ffff:ffff:ffff:ffff:ffff:ffff")},
}
// isPrivateSubnet - check to see if this ip is in a private subnet
func isPrivateSubnet(ipAddress net.IP) bool {
inRange := func(r ipRange, ipAddress net.IP) bool {
// ensure the IPs are in the same format for comparison
ipAddress = ipAddress.To16()
r.start = r.start.To16()
r.end = r.end.To16()
return bytes.Compare(ipAddress, r.start) >= 0 && bytes.Compare(ipAddress, r.end) <= 0
}
for _, r := range privateRanges {
if inRange(r, ipAddress) {
return true
}
}
return false
}

144
app/proxy/only_from_test.go Normal file
View File

@ -0,0 +1,144 @@
package proxy
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/umputun/reproxy/app/discovery"
)
func TestOnlyFrom_Handler(t *testing.T) {
tbl := []struct {
name string
lookups []OFLookup
allowedIPs []string
remoteAddr string
realIP string
forwardedFor string
expectedStatusCode int
}{
{
name: "allowed IP",
lookups: []OFLookup{OFRemoteAddr},
allowedIPs: []string{"192.168.1.1"},
remoteAddr: "192.168.1.1:1234",
expectedStatusCode: http.StatusOK,
},
{
name: "disallowed IP",
lookups: []OFLookup{OFRemoteAddr},
allowedIPs: []string{"192.168.1.1"},
remoteAddr: "192.168.1.2:1234",
expectedStatusCode: http.StatusForbidden,
},
{
name: "no restrictions",
lookups: []OFLookup{OFRemoteAddr},
allowedIPs: []string{},
remoteAddr: "192.168.1.2:1234",
expectedStatusCode: http.StatusOK,
},
{
name: "allowed IP with RealIP lookup",
lookups: []OFLookup{OFRealIP},
allowedIPs: []string{"192.168.1.1"},
realIP: "192.168.1.1",
expectedStatusCode: http.StatusOK,
},
{
name: "disallowed IP with RealIP lookup",
lookups: []OFLookup{OFRealIP},
allowedIPs: []string{"192.168.1.1"},
realIP: "192.168.1.2",
expectedStatusCode: http.StatusForbidden,
},
{
name: "allowed IP with Forwarded lookup",
lookups: []OFLookup{OFForwarded},
allowedIPs: []string{"192.168.1.1"},
forwardedFor: "192.168.1.1",
expectedStatusCode: http.StatusOK,
},
{
name: "allowed IP with Forwarded lookup, mix private and public IPs",
lookups: []OFLookup{OFForwarded},
allowedIPs: []string{"8.8.8.8"},
forwardedFor: "192.168.1.1, 10.0.0.5, 8.8.8.8, 10.10.10.10",
expectedStatusCode: http.StatusOK,
},
{
name: "disallowed IP with Forwarded lookup",
lookups: []OFLookup{OFForwarded},
allowedIPs: []string{"192.168.1.1"},
forwardedFor: "192.168.1.2",
expectedStatusCode: http.StatusForbidden,
},
{
name: "multiple lookups, allowed IP",
lookups: []OFLookup{OFRemoteAddr, OFRealIP},
allowedIPs: []string{"192.168.1.1", "192.168.1.2"},
remoteAddr: "192.168.1.2:1234",
realIP: "192.168.1.1",
expectedStatusCode: http.StatusOK,
},
{
name: "multiple lookups, disallowed IP",
lookups: []OFLookup{OFRemoteAddr, OFRealIP},
allowedIPs: []string{"192.168.1.1", "192.168.1.2"},
remoteAddr: "192.168.1.3:1234",
realIP: "192.168.1.3",
expectedStatusCode: http.StatusForbidden,
},
{
name: "CIDR block, allowed IP",
lookups: []OFLookup{OFRemoteAddr},
allowedIPs: []string{"192.168.1.0/24"},
remoteAddr: "192.168.1.2:1234",
expectedStatusCode: http.StatusOK,
},
{
name: "CIDR block, disallowed IP",
lookups: []OFLookup{OFRemoteAddr},
allowedIPs: []string{"192.168.1.0/24"},
remoteAddr: "192.168.2.2:1234",
expectedStatusCode: http.StatusForbidden,
},
{
name: "invalid remote address format",
lookups: []OFLookup{OFRemoteAddr},
allowedIPs: []string{"192.168.1.1"},
remoteAddr: "invalid_format",
expectedStatusCode: http.StatusForbidden,
},
{
name: "empty remote address",
lookups: []OFLookup{OFRemoteAddr},
allowedIPs: []string{"192.168.1.1"},
remoteAddr: "",
expectedStatusCode: http.StatusForbidden,
},
}
for _, tt := range tbl {
t.Run(tt.name, func(t *testing.T) {
onlyFrom := NewOnlyFrom(tt.lookups...)
handler := onlyFrom.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
req := httptest.NewRequest("GET", "http://example.com/foo", http.NoBody)
req.RemoteAddr = tt.remoteAddr
req.Header.Set("X-Real-IP", tt.realIP)
req.Header.Set("X-Forwarded-For", tt.forwardedFor)
req = req.WithContext(context.WithValue(req.Context(),
ctxMatch, discovery.MatchedRoute{Mapper: discovery.URLMapper{OnlyFromIPs: tt.allowedIPs}}))
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, tt.expectedStatusCode, rr.Code)
})
}
}

View File

@ -28,27 +28,27 @@ import (
// Http is a proxy server for both http and https
type Http struct { // nolint golint
Matcher
Address string
AssetsLocation string
AssetsWebRoot string
Assets404 string
AssetsSPA bool
MaxBodySize int64
GzEnabled bool
ProxyHeaders []string
DropHeader []string
SSLConfig SSLConfig
Version string
AccessLog io.Writer
StdOutEnabled bool
Signature bool
Timeouts Timeouts
CacheControl MiddlewareProvider
Metrics MiddlewareProvider
PluginConductor MiddlewareProvider
Reporter Reporter
LBSelector func(len int) int
Address string
AssetsLocation string
AssetsWebRoot string
Assets404 string
AssetsSPA bool
MaxBodySize int64
GzEnabled bool
ProxyHeaders []string
DropHeader []string
SSLConfig SSLConfig
Version string
AccessLog io.Writer
StdOutEnabled bool
Signature bool
Timeouts Timeouts
CacheControl MiddlewareProvider
Metrics MiddlewareProvider
PluginConductor MiddlewareProvider
Reporter Reporter
LBSelector func(len int) int
OnlyFrom *OnlyFrom
BasicAuthEnabled bool
BasicAuthAllowed []string
@ -121,18 +121,19 @@ func (h *Http) Run(ctx context.Context) error {
}()
handler := R.Wrap(h.proxyHandler(),
R.Recoverer(log.Default()), // recover on errors
signatureHandler(h.Signature, h.Version), // send app signature
h.pingHandler, // respond to /ping
R.Recoverer(log.Default()), // recover on errors
signatureHandler(h.Signature, h.Version), // send app signature
h.OnlyFrom.Handler, // limit source (remote) IPs if defined
h.pingHandler, // respond to /ping
basicAuthHandler(h.BasicAuthEnabled, h.BasicAuthAllowed), // basic auth
h.healthMiddleware, // respond to /health
h.matchHandler, // set matched routes to context
limiterSystemHandler(h.ThrottleSystem), // limit total requests/sec
limiterUserHandler(h.ThrottleUser), // req/seq per user/route match
h.mgmtHandler(), // handles /metrics and /routes for prometheus
h.pluginHandler(), // prc to external plugins
headersHandler(h.ProxyHeaders, h.DropHeader), // add response headers and delete some request headers
accessLogHandler(h.AccessLog), // apache-format log file
h.healthMiddleware, // respond to /health
h.matchHandler, // set matched routes to context
limiterSystemHandler(h.ThrottleSystem), // limit total requests/sec
limiterUserHandler(h.ThrottleUser), // req/seq per user/route match
h.mgmtHandler(), // handles /metrics and /routes for prometheus
h.pluginHandler(), // prc to external plugins
headersHandler(h.ProxyHeaders, h.DropHeader), // add response headers and delete some request headers
accessLogHandler(h.AccessLog), // apache-format log file
stdoutLogHandler(h.StdOutEnabled, logger.New(logger.Log(log.Default()), logger.Prefix("[INFO]")).Handler),
maxReqSizeHandler(h.MaxBodySize), // limit request max size
gzipHandler(h.GzEnabled), // gzip response
@ -400,22 +401,22 @@ func (h *Http) makeHTTPServer(addr string, router http.Handler) *http.Server {
}
func (h *Http) setXRealIP(r *http.Request) {
remoteIP := r.Header.Get("X-Forwarded-For")
if remoteIP == "" {
remoteIP = r.RemoteAddr
}
ip, _, err := net.SplitHostPort(remoteIP)
if err != nil {
if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
// use the left-most non-private client IP address
// if there is no any non-private IP address, use the left-most address
r.Header.Set("X-Real-IP", preferPublicIP(strings.Split(forwarded, ",")))
return
}
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return
}
userIP := net.ParseIP(ip)
if userIP == nil {
return
}
r.Header.Add("X-Real-IP", ip)
r.Header.Set("X-Real-IP", ip)
}
// discoveredServers gets the list of servers discovered by providers.

View File

@ -34,6 +34,7 @@ func TestHttp_Do(t *testing.T) {
t.Logf("req: %v", r)
w.Header().Add("h1", "v1")
require.Equal(t, "127.0.0.1", r.Header.Get("X-Real-IP"))
require.Equal(t, "127.0.0.1", r.Header.Get("X-Forwarded-For"))
fmt.Fprintf(w, "response %s", r.URL.String())
}))
@ -59,7 +60,7 @@ func TestHttp_Do(t *testing.T) {
client := http.Client{}
{
t.Run("to 127.0.0.1, good", func(t *testing.T) {
req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody)
require.NoError(t, err)
resp, err := client.Do(req)
@ -75,9 +76,9 @@ func TestHttp_Do(t *testing.T) {
assert.Equal(t, "v1", resp.Header.Get("h1"))
assert.Equal(t, "vv1", resp.Header.Get("hh1"))
assert.Equal(t, "vv2", resp.Header.Get("hh2"))
}
})
{
t.Run("to localhost, good", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/api/something")
require.NoError(t, err)
defer resp.Body.Close()
@ -89,9 +90,9 @@ func TestHttp_Do(t *testing.T) {
assert.Equal(t, "response /123/something", string(body))
assert.Equal(t, "reproxy", resp.Header.Get("App-Name"))
assert.Equal(t, "v1", resp.Header.Get("h1"))
}
})
{
t.Run("bad gateway", func(t *testing.T) {
resp, err := client.Get("http://127.0.0.1:" + strconv.Itoa(port) + "/bad/something")
require.NoError(t, err)
defer resp.Body.Close()
@ -100,9 +101,9 @@ func TestHttp_Do(t *testing.T) {
require.NoError(t, err)
assert.Contains(t, string(b), "Sorry for the inconvenience")
assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type"))
}
})
{
t.Run("url encode", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/api/test%20%25%20and%20&,%20and%20other%20characters%20@%28%29%5E%21")
require.NoError(t, err)
defer resp.Body.Close()
@ -114,7 +115,7 @@ func TestHttp_Do(t *testing.T) {
assert.Equal(t, "response /123/test%20%25%20and%20&,%20and%20other%20characters%20@%28%29%5E%21", string(body))
assert.Equal(t, "reproxy", resp.Header.Get("App-Name"))
assert.Equal(t, "v1", resp.Header.Get("h1"))
}
})
}
func TestHttp_DoWithAssets(t *testing.T) {
@ -153,7 +154,7 @@ func TestHttp_DoWithAssets(t *testing.T) {
client := http.Client{}
{
t.Run("api call", func(t *testing.T) {
req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody)
require.NoError(t, err)
resp, err := client.Do(req)
@ -167,9 +168,9 @@ func TestHttp_DoWithAssets(t *testing.T) {
assert.Equal(t, "response /567/something", string(body))
assert.Equal(t, "", resp.Header.Get("App-Method"))
assert.Equal(t, "v1", resp.Header.Get("h1"))
}
})
{
t.Run("static call, good", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/1.html")
require.NoError(t, err)
defer resp.Body.Close()
@ -182,9 +183,9 @@ func TestHttp_DoWithAssets(t *testing.T) {
assert.Equal(t, "", resp.Header.Get("App-Method"))
assert.Equal(t, "", resp.Header.Get("h1"))
assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control"))
}
})
{
t.Run("static call, bad", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad.html")
require.NoError(t, err)
defer resp.Body.Close()
@ -192,9 +193,9 @@ func TestHttp_DoWithAssets(t *testing.T) {
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, "404 page not found\n", string(body))
}
})
{
t.Run("bad url", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/svcbad")
require.NoError(t, err)
defer resp.Body.Close()
@ -203,7 +204,7 @@ func TestHttp_DoWithAssets(t *testing.T) {
require.NoError(t, err)
assert.Contains(t, string(body), "Server error")
assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type"))
}
})
}
func TestHttp_DoWithAssetsCustom404(t *testing.T) {
@ -243,7 +244,7 @@ func TestHttp_DoWithAssetsCustom404(t *testing.T) {
client := http.Client{}
{
t.Run("api call, found", func(t *testing.T) {
req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody)
require.NoError(t, err)
resp, err := client.Do(req)
@ -257,9 +258,9 @@ func TestHttp_DoWithAssetsCustom404(t *testing.T) {
assert.Equal(t, "response /567/something", string(body))
assert.Equal(t, "", resp.Header.Get("App-Method"))
assert.Equal(t, "v1", resp.Header.Get("h1"))
}
})
{
t.Run("static call, found", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/1.html")
require.NoError(t, err)
defer resp.Body.Close()
@ -272,9 +273,9 @@ func TestHttp_DoWithAssetsCustom404(t *testing.T) {
assert.Equal(t, "", resp.Header.Get("App-Method"))
assert.Equal(t, "", resp.Header.Get("h1"))
assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control"))
}
})
{
t.Run("static call, not found", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad.html")
require.NoError(t, err)
defer resp.Body.Close()
@ -284,9 +285,9 @@ func TestHttp_DoWithAssetsCustom404(t *testing.T) {
assert.Equal(t, "not found! blah blah blah\nthere is no spoon", string(body))
t.Logf("%+v", resp.Header)
assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type"))
}
})
{
t.Run("another static call, not found", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad2.html")
require.NoError(t, err)
defer resp.Body.Close()
@ -296,7 +297,7 @@ func TestHttp_DoWithAssetsCustom404(t *testing.T) {
assert.Equal(t, "not found! blah blah blah\nthere is no spoon", string(body))
t.Logf("%+v", resp.Header)
assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type"))
}
})
}
func TestHttp_DoWithSpaAssets(t *testing.T) {
@ -336,7 +337,7 @@ func TestHttp_DoWithSpaAssets(t *testing.T) {
client := http.Client{}
{
t.Run("api call, good", func(t *testing.T) {
req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody)
require.NoError(t, err)
resp, err := client.Do(req)
@ -350,9 +351,9 @@ func TestHttp_DoWithSpaAssets(t *testing.T) {
assert.Equal(t, "response /567/something", string(body))
assert.Equal(t, "", resp.Header.Get("App-Method"))
assert.Equal(t, "v1", resp.Header.Get("h1"))
}
})
{
t.Run("static call, good", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/1.html")
require.NoError(t, err)
defer resp.Body.Close()
@ -365,9 +366,9 @@ func TestHttp_DoWithSpaAssets(t *testing.T) {
assert.Equal(t, "", resp.Header.Get("App-Method"))
assert.Equal(t, "", resp.Header.Get("h1"))
assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control"))
}
})
{
t.Run("static call, not found server index", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad.html")
require.NoError(t, err)
defer resp.Body.Close()
@ -380,9 +381,9 @@ func TestHttp_DoWithSpaAssets(t *testing.T) {
assert.Equal(t, "", resp.Header.Get("App-Method"))
assert.Equal(t, "", resp.Header.Get("h1"))
assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control"))
}
})
{
t.Run("static call, bad url", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/svcbad")
require.NoError(t, err)
defer resp.Body.Close()
@ -391,7 +392,7 @@ func TestHttp_DoWithSpaAssets(t *testing.T) {
require.NoError(t, err)
assert.Contains(t, string(body), "Server error")
assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type"))
}
})
}
func TestHttp_DoWithAssetRules(t *testing.T) {
@ -715,16 +716,16 @@ func TestHttp_withBasicAuth(t *testing.T) {
client := http.Client{}
{
t.Run("no auth", func(t *testing.T) {
req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
}
})
{
t.Run("bad auth", func(t *testing.T) {
req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
req.SetBasicAuth("test", "badpasswd")
require.NoError(t, err)
@ -732,8 +733,9 @@ func TestHttp_withBasicAuth(t *testing.T) {
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
}
{
})
t.Run("good auth", func(t *testing.T) {
req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
req.SetBasicAuth("test", "passwd")
require.NoError(t, err)
@ -741,8 +743,9 @@ func TestHttp_withBasicAuth(t *testing.T) {
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
}
{
})
t.Run("good auth 2", func(t *testing.T) {
req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
req.SetBasicAuth("test2", "passwd2")
require.NoError(t, err)
@ -750,7 +753,7 @@ func TestHttp_withBasicAuth(t *testing.T) {
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
}
})
}
func TestHttp_toHttp(t *testing.T) {
@ -766,9 +769,9 @@ func TestHttp_toHttp(t *testing.T) {
}
h := Http{}
for i, tt := range tbl {
for _, tt := range tbl {
tt := tt
t.Run(strconv.Itoa(i), func(t *testing.T) {
t.Run(tt.addr, func(t *testing.T) {
assert.Equal(t, tt.res, h.toHTTP(tt.addr, tt.port))
})
}
@ -791,8 +794,8 @@ func TestHttp_isAssetRequest(t *testing.T) {
{"/static/", "/tmp", "", false},
}
for i, tt := range tbl {
t.Run(strconv.Itoa(i), func(t *testing.T) {
for _, tt := range tbl {
t.Run(tt.req, func(t *testing.T) {
h := Http{AssetsLocation: tt.assetsLocation, AssetsWebRoot: tt.assetsWebRoot}
r, err := http.NewRequest("GET", tt.req, http.NoBody)
require.NoError(t, err)
@ -803,56 +806,61 @@ func TestHttp_isAssetRequest(t *testing.T) {
}
func TestHttp_matchHandler(t *testing.T) {
tbl := []struct {
name string
matches discovery.Matches
res string
ok bool
}{
{
discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
name: "all alive destinations",
matches: discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
{Destination: "dest1", Alive: true},
{Destination: "dest2", Alive: true},
{Destination: "dest3", Alive: true},
}},
"dest1", true,
res: "dest1", ok: true,
},
{
discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
name: "second alive destination",
matches: discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
{Destination: "dest1", Alive: false},
{Destination: "dest2", Alive: true},
{Destination: "dest3", Alive: false},
}},
"dest2", true,
res: "dest2", ok: true,
},
{
discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
name: "one dead destination",
matches: discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
{Destination: "dest1", Alive: false},
{Destination: "dest2", Alive: true},
{Destination: "dest3", Alive: true},
}},
"dest2", true,
res: "dest2", ok: true,
},
{
discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
name: "last alive destination",
matches: discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
{Destination: "dest1", Alive: false},
{Destination: "dest2", Alive: false},
{Destination: "dest3", Alive: true},
}},
"dest3", true,
res: "dest3", ok: true,
},
{
discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
name: "all dead destinations",
matches: discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
{Destination: "dest1", Alive: false},
{Destination: "dest2", Alive: false},
{Destination: "dest3", Alive: false},
}},
"", false,
res: "", ok: false,
},
{
discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{}}, "", false,
name: "no destinations",
matches: discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{}}, res: "", ok: false,
},
}
@ -864,9 +872,8 @@ func TestHttp_matchHandler(t *testing.T) {
}
client := http.Client{}
for i, tt := range tbl {
t.Run(strconv.Itoa(i), func(t *testing.T) {
for _, tt := range tbl {
t.Run(tt.name, func(t *testing.T) {
h := Http{Matcher: matcherMock, LBSelector: func(len int) int { return 0 }}
handler := h.matchHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("req: %+v", r)
@ -893,7 +900,6 @@ func TestHttp_matchHandler(t *testing.T) {
}
func TestHttp_discoveredServers(t *testing.T) {
calls := 0
m := &MatcherMock{ServersFunc: func() []string {
defer func() { calls++ }()