1
0
mirror of https://github.com/umputun/reproxy.git synced 2025-02-16 18:34:30 +02:00

add middleware to optionally allow requests from giving ips/ranges

add new remote param to docker and file providers

lint: http nil body

add support of remote ips to consul provider

local implementation of onlyfrom middleware

lint: missing comment

make proxy tests more readable

preffer public IP if any forwwarded
This commit is contained in:
Umputun 2023-11-25 13:54:14 -06:00
parent 36a5443378
commit a896f08eec
16 changed files with 531 additions and 142 deletions

View File

@ -37,6 +37,7 @@ type URLMapper struct {
PingURL string PingURL string
MatchType MatchType MatchType MatchType
RedirectType RedirectType RedirectType RedirectType
OnlyFromIPs []string
AssetsLocation string // local FS root location AssetsLocation string // local FS root location
AssetsWebRoot string // web root location AssetsWebRoot string // web root location
@ -484,16 +485,6 @@ func (s *Service) mergeEvents(ctx context.Context, chs ...<-chan ProviderID) <-c
return out return out
} }
// Contains checks if the input string (e) in the given slice
func Contains(e string, s []string) bool {
for _, a := range s {
if a == e {
return true
}
}
return false
}
// IsAlive indicates whether mapper destination is alive // IsAlive indicates whether mapper destination is alive
func (m URLMapper) IsAlive() bool { func (m URLMapper) IsAlive() bool {
return !m.dead return !m.dead
@ -515,3 +506,24 @@ func (m URLMapper) ping() (string, error) {
return "", err return "", err
} }
// Contains checks if the input string (e) in the given slice
func Contains(e string, s []string) bool {
for _, a := range s {
if a == e {
return true
}
}
return false
}
// ParseOnlyFrom parses comma separated list of IPs
func ParseOnlyFrom(s string) (res []string) {
if s == "" {
return []string{}
}
for _, v := range strings.Split(s, ",") {
res = append(res, strings.TrimSpace(v))
}
return res
}

View File

@ -39,7 +39,7 @@ func TestService_Run(t *testing.T) {
ListFunc: func() ([]URLMapper, error) { ListFunc: func() ([]URLMapper, error) {
return []URLMapper{ return []URLMapper{
{Server: "localhost", SrcMatch: *regexp.MustCompile("/api/svc3/xyz"), {Server: "localhost", SrcMatch: *regexp.MustCompile("/api/svc3/xyz"),
Dst: "http://127.0.0.3:8080/blah3/xyz", ProviderID: PIDocker}, Dst: "http://127.0.0.3:8080/blah3/xyz", ProviderID: PIDocker, OnlyFromIPs: []string{"127.0.0.1"}},
}, nil }, nil
}, },
} }
@ -66,6 +66,7 @@ func TestService_Run(t *testing.T) {
assert.Equal(t, "localhost", mappers[0].Server) assert.Equal(t, "localhost", mappers[0].Server)
assert.Equal(t, "/api/svc3/xyz", mappers[0].SrcMatch.String()) assert.Equal(t, "/api/svc3/xyz", mappers[0].SrcMatch.String())
assert.Equal(t, "http://127.0.0.3:8080/blah3/xyz", mappers[0].Dst) assert.Equal(t, "http://127.0.0.3:8080/blah3/xyz", mappers[0].Dst)
assert.Equal(t, []string{"127.0.0.1"}, mappers[0].OnlyFromIPs)
assert.Equal(t, 1, len(p1.EventsCalls())) assert.Equal(t, 1, len(p1.EventsCalls()))
assert.Equal(t, 1, len(p2.EventsCalls())) assert.Equal(t, 1, len(p2.EventsCalls()))
@ -104,7 +105,8 @@ func TestService_Match(t *testing.T) {
}, },
ListFunc: func() ([]URLMapper, error) { ListFunc: func() ([]URLMapper, error) {
return []URLMapper{ return []URLMapper{
{SrcMatch: *regexp.MustCompile("/api/svc3/xyz"), Dst: "http://127.0.0.3:8080/blah3/xyz", ProviderID: PIDocker}, {SrcMatch: *regexp.MustCompile("/api/svc3/xyz"), Dst: "http://127.0.0.3:8080/blah3/xyz",
OnlyFromIPs: []string{"127.0.0.1", "192.168.1.0/24"}, ProviderID: PIDocker},
{SrcMatch: *regexp.MustCompile("/web"), Dst: "/var/web", ProviderID: PIDocker, MatchType: MTStatic, {SrcMatch: *regexp.MustCompile("/web"), Dst: "/var/web", ProviderID: PIDocker, MatchType: MTStatic,
AssetsWebRoot: "/web", AssetsLocation: "/var/web"}, AssetsWebRoot: "/web", AssetsLocation: "/var/web"},
{SrcMatch: *regexp.MustCompile("/www/"), Dst: "/var/web", ProviderID: PIDocker, MatchType: MTStatic, {SrcMatch: *regexp.MustCompile("/www/"), Dst: "/var/web", ProviderID: PIDocker, MatchType: MTStatic,
@ -131,9 +133,11 @@ func TestService_Match(t *testing.T) {
res Matches res Matches
}{ }{
{"example.com", "/api/svc3/xyz/something", Matches{MTProxy, []MatchedRoute{ {"example.com", "/api/svc3/xyz/something", Matches{MTProxy, []MatchedRoute{
{Destination: "http://127.0.0.3:8080/blah3/xyz/something", Alive: true}}}}, {Destination: "http://127.0.0.3:8080/blah3/xyz/something", Alive: true,
Mapper: URLMapper{OnlyFromIPs: []string{"127.0.0.1", "192.168.1.0/24"}}}}}},
{"example.com", "/api/svc3/xyz", Matches{MTProxy, []MatchedRoute{{ {"example.com", "/api/svc3/xyz", Matches{MTProxy, []MatchedRoute{{
Destination: "http://127.0.0.3:8080/blah3/xyz", Alive: true}}}}, Destination: "http://127.0.0.3:8080/blah3/xyz", Alive: true,
Mapper: URLMapper{OnlyFromIPs: []string{"127.0.0.1", "192.168.1.0/24"}}}}}},
{"abc.example.com", "/api/svc1/1234", Matches{MTProxy, []MatchedRoute{ {"abc.example.com", "/api/svc1/1234", Matches{MTProxy, []MatchedRoute{
{Destination: "http://127.0.0.1:8080/blah1/1234", Alive: true}}}}, {Destination: "http://127.0.0.1:8080/blah1/1234", Alive: true}}}},
{"zzz.example.com", "/aaa/api/svc1/1234", Matches{MTProxy, nil}}, {"zzz.example.com", "/aaa/api/svc1/1234", Matches{MTProxy, nil}},
@ -167,6 +171,7 @@ func TestService_Match(t *testing.T) {
for i := 0; i < len(res.Routes); i++ { for i := 0; i < len(res.Routes); i++ {
assert.Equal(t, tt.res.Routes[i].Alive, res.Routes[i].Alive) assert.Equal(t, tt.res.Routes[i].Alive, res.Routes[i].Alive)
assert.Equal(t, tt.res.Routes[i].Destination, res.Routes[i].Destination) assert.Equal(t, tt.res.Routes[i].Destination, res.Routes[i].Destination)
assert.Equal(t, tt.res.Routes[i].Mapper.OnlyFromIPs, res.Routes[i].Mapper.OnlyFromIPs)
} }
assert.Equal(t, tt.res.MatchType, res.MatchType) assert.Equal(t, tt.res.MatchType, res.MatchType)
}) })
@ -608,3 +613,39 @@ func TestCheckHealth(t *testing.T) {
assert.NoError(t, res[ts.URL]) assert.NoError(t, res[ts.URL])
assert.NoError(t, res[ts2.URL]) assert.NoError(t, res[ts2.URL])
} }
func TestParseOnlyFrom(t *testing.T) {
tbl := []struct {
name string
input string
expected []string
}{
{
name: "empty string",
input: "",
expected: []string{},
},
{
name: "single IP",
input: "192.168.1.1",
expected: []string{"192.168.1.1"},
},
{
name: "multiple IPs",
input: "192.168.1.1, 192.168.1.2, 192.168.1.3, 10.0.0.0/16",
expected: []string{"192.168.1.1", "192.168.1.2", "192.168.1.3", "10.0.0.0/16"},
},
{
name: "multiple IPs with extra spaces",
input: " 192.168.1.1 , 192.168.1.2 , 192.168.1.3 ",
expected: []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"},
},
}
for _, tt := range tbl {
t.Run(tt.name, func(t *testing.T) {
result := ParseOnlyFrom(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}

View File

@ -3,12 +3,13 @@ package consulcatalog
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/umputun/reproxy/app/discovery"
"log" "log"
"regexp" "regexp"
"sort" "sort"
"strings" "strings"
"time" "time"
"github.com/umputun/reproxy/app/discovery"
) )
//go:generate moq -out consul_client_mock.go -skip-ensure -fmt goimports . ConsulClient //go:generate moq -out consul_client_mock.go -skip-ensure -fmt goimports . ConsulClient
@ -139,6 +140,7 @@ func (cc *ConsulCatalog) List() ([]discovery.URLMapper, error) {
destURL := fmt.Sprintf("http://%s:%d/$1", c.ServiceAddress, c.ServicePort) destURL := fmt.Sprintf("http://%s:%d/$1", c.ServiceAddress, c.ServicePort)
pingURL := fmt.Sprintf("http://%s:%d/ping", c.ServiceAddress, c.ServicePort) pingURL := fmt.Sprintf("http://%s:%d/ping", c.ServiceAddress, c.ServicePort)
server := "*" server := "*"
onlyFrom := []string{}
if v, ok := c.Labels["reproxy.enabled"]; ok && (v == "true" || v == "yes" || v == "1") { if v, ok := c.Labels["reproxy.enabled"]; ok && (v == "true" || v == "yes" || v == "1") {
enabled = true enabled = true
@ -159,6 +161,10 @@ func (cc *ConsulCatalog) List() ([]discovery.URLMapper, error) {
server = v server = v
} }
if v, ok := c.Labels["reproxy.remote"]; ok {
onlyFrom = discovery.ParseOnlyFrom(v)
}
if v, ok := c.Labels["reproxy.ping"]; ok { if v, ok := c.Labels["reproxy.ping"]; ok {
enabled = true enabled = true
pingURL = fmt.Sprintf("http://%s:%d%s", c.ServiceAddress, c.ServicePort, v) pingURL = fmt.Sprintf("http://%s:%d%s", c.ServiceAddress, c.ServicePort, v)
@ -177,7 +183,7 @@ func (cc *ConsulCatalog) List() ([]discovery.URLMapper, error) {
// server label may have multiple, comma separated servers // server label may have multiple, comma separated servers
for _, srv := range strings.Split(server, ",") { for _, srv := range strings.Split(server, ",") {
res = append(res, discovery.URLMapper{Server: strings.TrimSpace(srv), SrcMatch: *srcRegex, Dst: destURL, res = append(res, discovery.URLMapper{Server: strings.TrimSpace(srv), SrcMatch: *srcRegex, Dst: destURL,
PingURL: pingURL, ProviderID: discovery.PIConsulCatalog}) OnlyFromIPs: onlyFrom, PingURL: pingURL, ProviderID: discovery.PIConsulCatalog})
} }
} }

View File

@ -3,12 +3,14 @@ package consulcatalog
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/umputun/reproxy/app/discovery"
"sort" "sort"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/umputun/reproxy/app/discovery"
) )
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
@ -62,7 +64,8 @@ func TestConsulCatalog_List(t *testing.T) {
ServiceAddress: "addr3", ServiceAddress: "addr3",
ServicePort: 3000, ServicePort: 3000,
Labels: map[string]string{"reproxy.route": "^/api/123/(.*)", "reproxy.dest": "/blah/$1", Labels: map[string]string{"reproxy.route": "^/api/123/(.*)", "reproxy.dest": "/blah/$1",
"reproxy.server": "example.com,domain.com", "reproxy.ping": "/ping", "reproxy.enabled": "yes"}, "reproxy.server": "example.com,domain.com", "reproxy.ping": "/ping",
"reproxy.enabled": "yes", "reproxy.remote": "127.0.0.1, 192.168.1.0/24"},
}, },
{ {
ServiceID: "id4", ServiceID: "id4",
@ -91,21 +94,25 @@ func TestConsulCatalog_List(t *testing.T) {
assert.Equal(t, "http://addr3:3000/blah/$1", res[0].Dst) assert.Equal(t, "http://addr3:3000/blah/$1", res[0].Dst)
assert.Equal(t, "example.com", res[0].Server) assert.Equal(t, "example.com", res[0].Server)
assert.Equal(t, "http://addr3:3000/ping", res[0].PingURL) assert.Equal(t, "http://addr3:3000/ping", res[0].PingURL)
assert.Equal(t, []string{"127.0.0.1", "192.168.1.0/24"}, res[0].OnlyFromIPs)
assert.Equal(t, "^/api/123/(.*)", res[1].SrcMatch.String()) assert.Equal(t, "^/api/123/(.*)", res[1].SrcMatch.String())
assert.Equal(t, "http://addr3:3000/blah/$1", res[1].Dst) assert.Equal(t, "http://addr3:3000/blah/$1", res[1].Dst)
assert.Equal(t, "domain.com", res[1].Server) assert.Equal(t, "domain.com", res[1].Server)
assert.Equal(t, "http://addr3:3000/ping", res[1].PingURL) assert.Equal(t, "http://addr3:3000/ping", res[1].PingURL)
assert.Equal(t, []string{"127.0.0.1", "192.168.1.0/24"}, res[1].OnlyFromIPs)
assert.Equal(t, "^/(.*)", res[2].SrcMatch.String()) assert.Equal(t, "^/(.*)", res[2].SrcMatch.String())
assert.Equal(t, "http://addr44:4000/$1", res[2].Dst) assert.Equal(t, "http://addr44:4000/$1", res[2].Dst)
assert.Equal(t, "http://addr44:4000/ping", res[2].PingURL) assert.Equal(t, "http://addr44:4000/ping", res[2].PingURL)
assert.Equal(t, "*", res[2].Server) assert.Equal(t, "*", res[2].Server)
assert.Equal(t, []string{}, res[2].OnlyFromIPs)
assert.Equal(t, "^/(.*)", res[3].SrcMatch.String()) assert.Equal(t, "^/(.*)", res[3].SrcMatch.String())
assert.Equal(t, "http://addr2:2000/$1", res[3].Dst) assert.Equal(t, "http://addr2:2000/$1", res[3].Dst)
assert.Equal(t, "http://addr2:2000/ping", res[3].PingURL) assert.Equal(t, "http://addr2:2000/ping", res[3].PingURL)
assert.Equal(t, "*", res[3].Server) assert.Equal(t, "*", res[3].Server)
assert.Equal(t, []string{}, res[3].OnlyFromIPs)
} }
func TestConsulCatalog_serviceListWasChanged(t *testing.T) { func TestConsulCatalog_serviceListWasChanged(t *testing.T) {

View File

@ -103,6 +103,7 @@ func (d *Docker) parseContainerInfo(c containerInfo) (res []discovery.URLMapper)
// defaults // defaults
destURL, pingURL, server := fmt.Sprintf("http://%s:%d/$1", c.IP, port), fmt.Sprintf("http://%s:%d/ping", c.IP, port), "*" destURL, pingURL, server := fmt.Sprintf("http://%s:%d/$1", c.IP, port), fmt.Sprintf("http://%s:%d/ping", c.IP, port), "*"
assetsWebRoot, assetsLocation, assetsSPA := "", "", false assetsWebRoot, assetsLocation, assetsSPA := "", "", false
onlyFrom := []string{}
if d.AutoAPI && n == 0 { if d.AutoAPI && n == 0 {
enabled = true enabled = true
@ -133,6 +134,10 @@ func (d *Docker) parseContainerInfo(c containerInfo) (res []discovery.URLMapper)
server = v server = v
} }
if v, ok := d.labelN(c.Labels, n, "remote"); ok {
onlyFrom = discovery.ParseOnlyFrom(v)
}
if v, ok := d.labelN(c.Labels, n, "ping"); ok { if v, ok := d.labelN(c.Labels, n, "ping"); ok {
enabled = true enabled = true
if strings.HasPrefix(v, "http://") || strings.HasPrefix(v, "https://") { if strings.HasPrefix(v, "http://") || strings.HasPrefix(v, "https://") {
@ -171,7 +176,7 @@ func (d *Docker) parseContainerInfo(c containerInfo) (res []discovery.URLMapper)
// docker server label may have multiple, comma separated servers // docker server label may have multiple, comma separated servers
for _, srv := range strings.Split(server, ",") { for _, srv := range strings.Split(server, ",") {
mp := discovery.URLMapper{Server: strings.TrimSpace(srv), SrcMatch: *srcRegex, Dst: destURL, mp := discovery.URLMapper{Server: strings.TrimSpace(srv), SrcMatch: *srcRegex, Dst: destURL,
PingURL: pingURL, ProviderID: discovery.PIDocker, MatchType: discovery.MTProxy} PingURL: pingURL, OnlyFromIPs: onlyFrom, ProviderID: discovery.PIDocker, MatchType: discovery.MTProxy}
// for assets we add the second proxy mapping only if explicitly requested // for assets we add the second proxy mapping only if explicitly requested
if assetsWebRoot != "" && explicit { if assetsWebRoot != "" && explicit {

View File

@ -30,7 +30,7 @@ func TestDocker_List(t *testing.T) {
{ {
Name: "c1", State: "running", IP: "127.0.0.2", Ports: []int{12345}, Name: "c1", State: "running", IP: "127.0.0.2", Ports: []int{12345},
Labels: map[string]string{"reproxy.route": "^/api/123/(.*)", "reproxy.dest": "/blah/$1", Labels: map[string]string{"reproxy.route": "^/api/123/(.*)", "reproxy.dest": "/blah/$1",
"reproxy.server": "example.com", "reproxy.ping": "/ping"}, "reproxy.server": "example.com", "reproxy.ping": "/ping", "reproxy.remote": "192.168.1.0/24, 127.0.0.1"},
}, },
{ {
Name: "c1", State: "running", IP: "127.0.0.21", Ports: []int{12345}, Name: "c1", State: "running", IP: "127.0.0.21", Ports: []int{12345},
@ -64,21 +64,25 @@ func TestDocker_List(t *testing.T) {
assert.Equal(t, "http://127.0.0.2:12345/blah/$1", res[0].Dst) assert.Equal(t, "http://127.0.0.2:12345/blah/$1", res[0].Dst)
assert.Equal(t, "example.com", res[0].Server) assert.Equal(t, "example.com", res[0].Server)
assert.Equal(t, "http://127.0.0.2:12345/ping", res[0].PingURL) assert.Equal(t, "http://127.0.0.2:12345/ping", res[0].PingURL)
assert.Equal(t, []string{"192.168.1.0/24", "127.0.0.1"}, res[0].OnlyFromIPs)
assert.Equal(t, "^/api/90/(.*)", res[1].SrcMatch.String()) assert.Equal(t, "^/api/90/(.*)", res[1].SrcMatch.String())
assert.Equal(t, "http://example.com/blah/$1", res[1].Dst) assert.Equal(t, "http://example.com/blah/$1", res[1].Dst)
assert.Equal(t, "https://example.com//ping", res[1].PingURL) assert.Equal(t, "https://example.com//ping", res[1].PingURL)
assert.Equal(t, "example.com", res[1].Server) assert.Equal(t, "example.com", res[1].Server)
assert.Equal(t, []string{}, res[1].OnlyFromIPs)
assert.Equal(t, "^/c2/(.*)", res[2].SrcMatch.String()) assert.Equal(t, "^/c2/(.*)", res[2].SrcMatch.String())
assert.Equal(t, "http://127.0.0.3:12346/$1", res[2].Dst) assert.Equal(t, "http://127.0.0.3:12346/$1", res[2].Dst)
assert.Equal(t, "http://127.0.0.3:12346/ping", res[2].PingURL) assert.Equal(t, "http://127.0.0.3:12346/ping", res[2].PingURL)
assert.Equal(t, "*", res[2].Server) assert.Equal(t, "*", res[2].Server)
assert.Equal(t, []string{}, res[2].OnlyFromIPs)
assert.Equal(t, "^/a/(.*)", res[3].SrcMatch.String()) assert.Equal(t, "^/a/(.*)", res[3].SrcMatch.String())
assert.Equal(t, "http://127.0.0.2:12348/a/$1", res[3].Dst) assert.Equal(t, "http://127.0.0.2:12348/a/$1", res[3].Dst)
assert.Equal(t, "http://127.0.0.2:12348/ping", res[3].PingURL) assert.Equal(t, "http://127.0.0.2:12348/ping", res[3].PingURL)
assert.Equal(t, "example.com", res[3].Server) assert.Equal(t, "example.com", res[3].Server)
assert.Equal(t, []string{}, res[3].OnlyFromIPs)
} }
func TestDocker_ListMulti(t *testing.T) { func TestDocker_ListMulti(t *testing.T) {

View File

@ -84,6 +84,7 @@ func (d *File) List() (res []discovery.URLMapper, err error) {
Ping string `yaml:"ping"` Ping string `yaml:"ping"`
AssetsEnabled bool `yaml:"assets"` AssetsEnabled bool `yaml:"assets"`
AssetsSPA bool `yaml:"spa"` AssetsSPA bool `yaml:"spa"`
OnlyFrom string `yaml:"remote"`
} }
fh, err := os.Open(d.FileName) fh, err := os.Open(d.FileName)
if err != nil { if err != nil {
@ -106,12 +107,13 @@ func (d *File) List() (res []discovery.URLMapper, err error) {
srv = "*" srv = "*"
} }
mapper := discovery.URLMapper{ mapper := discovery.URLMapper{
Server: srv, Server: srv,
SrcMatch: *rx, SrcMatch: *rx,
Dst: f.Dest, Dst: f.Dest,
PingURL: f.Ping, PingURL: f.Ping,
ProviderID: discovery.PIFile, ProviderID: discovery.PIFile,
MatchType: discovery.MTProxy, MatchType: discovery.MTProxy,
OnlyFromIPs: discovery.ParseOnlyFrom(f.OnlyFrom),
} }
if f.AssetsEnabled || f.AssetsSPA { if f.AssetsEnabled || f.AssetsSPA {
mapper.MatchType = discovery.MTStatic mapper.MatchType = discovery.MTStatic

View File

@ -113,18 +113,21 @@ func TestFile_List(t *testing.T) {
assert.Equal(t, "", res[0].PingURL) assert.Equal(t, "", res[0].PingURL)
assert.Equal(t, "srv.example.com", res[0].Server) assert.Equal(t, "srv.example.com", res[0].Server)
assert.Equal(t, discovery.MTProxy, res[0].MatchType) assert.Equal(t, discovery.MTProxy, res[0].MatchType)
assert.Equal(t, []string{}, res[0].OnlyFromIPs)
assert.Equal(t, "^/api/svc1/(.*)", res[1].SrcMatch.String()) assert.Equal(t, "^/api/svc1/(.*)", res[1].SrcMatch.String())
assert.Equal(t, "http://127.0.0.1:8080/blah1/$1", res[1].Dst) assert.Equal(t, "http://127.0.0.1:8080/blah1/$1", res[1].Dst)
assert.Equal(t, "", res[1].PingURL) assert.Equal(t, "", res[1].PingURL)
assert.Equal(t, "*", res[1].Server) assert.Equal(t, "*", res[1].Server)
assert.Equal(t, discovery.MTProxy, res[1].MatchType) assert.Equal(t, discovery.MTProxy, res[1].MatchType)
assert.Equal(t, []string{}, res[0].OnlyFromIPs)
assert.Equal(t, "/api/svc3/xyz", res[2].SrcMatch.String()) assert.Equal(t, "/api/svc3/xyz", res[2].SrcMatch.String())
assert.Equal(t, "http://127.0.0.3:8080/blah3/xyz", res[2].Dst) assert.Equal(t, "http://127.0.0.3:8080/blah3/xyz", res[2].Dst)
assert.Equal(t, "http://127.0.0.3:8080/ping", res[2].PingURL) assert.Equal(t, "http://127.0.0.3:8080/ping", res[2].PingURL)
assert.Equal(t, "*", res[2].Server) assert.Equal(t, "*", res[2].Server)
assert.Equal(t, discovery.MTProxy, res[2].MatchType) assert.Equal(t, discovery.MTProxy, res[2].MatchType)
assert.Equal(t, []string{}, res[0].OnlyFromIPs)
assert.Equal(t, "/web/", res[3].SrcMatch.String()) assert.Equal(t, "/web/", res[3].SrcMatch.String())
assert.Equal(t, "/var/web", res[3].Dst) assert.Equal(t, "/var/web", res[3].Dst)
@ -132,6 +135,7 @@ func TestFile_List(t *testing.T) {
assert.Equal(t, "*", res[3].Server) assert.Equal(t, "*", res[3].Server)
assert.Equal(t, discovery.MTStatic, res[3].MatchType) assert.Equal(t, discovery.MTStatic, res[3].MatchType)
assert.Equal(t, false, res[3].AssetsSPA) assert.Equal(t, false, res[3].AssetsSPA)
assert.Equal(t, []string{"192.168.1.0/24", "124.0.0.1"}, res[3].OnlyFromIPs)
assert.Equal(t, "/web2/", res[4].SrcMatch.String()) assert.Equal(t, "/web2/", res[4].SrcMatch.String())
assert.Equal(t, "/var/web2", res[4].Dst) assert.Equal(t, "/var/web2", res[4].Dst)
@ -139,4 +143,5 @@ func TestFile_List(t *testing.T) {
assert.Equal(t, "*", res[4].Server) assert.Equal(t, "*", res[4].Server)
assert.Equal(t, discovery.MTStatic, res[4].MatchType) assert.Equal(t, discovery.MTStatic, res[4].MatchType)
assert.Equal(t, true, res[4].AssetsSPA) assert.Equal(t, true, res[4].AssetsSPA)
assert.Equal(t, []string{}, res[0].OnlyFromIPs)
} }

View File

@ -9,7 +9,7 @@ import (
"github.com/umputun/reproxy/app/discovery" "github.com/umputun/reproxy/app/discovery"
) )
// Static provider, rules are server,from,to // Static provider, rules are server,source_url,destination[,ping]
type Static struct { type Static struct {
Rules []string // each rule is 4 elements comma separated - server,source_url,destination,ping Rules []string // each rule is 4 elements comma separated - server,source_url,destination,ping
} }

View File

@ -1,7 +1,7 @@
default: default:
- {route: "^/api/svc1/(.*)", dest: "http://127.0.0.1:8080/blah1/$1"} - {route: "^/api/svc1/(.*)", dest: "http://127.0.0.1:8080/blah1/$1"}
- {route: "/api/svc3/xyz", dest: "http://127.0.0.3:8080/blah3/xyz", "ping": "http://127.0.0.3:8080/ping"} - {route: "/api/svc3/xyz", dest: "http://127.0.0.3:8080/blah3/xyz", "ping": "http://127.0.0.3:8080/ping"}
- {route: "/web/", dest: "/var/web", "assets": yes} - {route: "/web/", dest: "/var/web", "assets": yes, "remote": "192.168.1.0/24, 124.0.0.1"}
- {route: "/web2/", dest: "/var/web2", "spa": yes} - {route: "/web2/", dest: "/var/web2", "spa": yes}
srv.example.com: srv.example.com:
- {route: "^/api/svc2/(.*)", dest: "http://127.0.0.2:8080/blah2/$1/abc"} - {route: "^/api/svc2/(.*)", dest: "http://127.0.0.2:8080/blah2/$1/abc"}

View File

@ -29,14 +29,14 @@ import (
) )
var opts struct { var opts struct {
Listen string `short:"l" long:"listen" env:"LISTEN" description:"listen on host:port (default: 0.0.0.0:8080/8443 under docker, 127.0.0.1:80/443 without)"` Listen string `short:"l" long:"listen" env:"LISTEN" description:"listen on host:port (default: 0.0.0.0:8080/8443 under docker, 127.0.0.1:80/443 without)"`
MaxSize string `short:"m" long:"max" env:"MAX_SIZE" default:"64K" description:"max request size"` MaxSize string `short:"m" long:"max" env:"MAX_SIZE" default:"64K" description:"max request size"`
GzipEnabled bool `short:"g" long:"gzip" env:"GZIP" description:"enable gz compression"` GzipEnabled bool `short:"g" long:"gzip" env:"GZIP" description:"enable gz compression"`
ProxyHeaders []string `short:"x" long:"header" description:"outgoing proxy headers to add"` // env HEADER split in code to allow , inside "" ProxyHeaders []string `short:"x" long:"header" description:"outgoing proxy headers to add"` // env HEADER split in code to allow , inside ""
DropHeaders []string `long:"drop-header" env:"DROP_HEADERS" description:"incoming headers to drop" env-delim:","` DropHeaders []string `long:"drop-header" env:"DROP_HEADERS" description:"incoming headers to drop" env-delim:","`
AuthBasicHtpasswd string `long:"basic-htpasswd" env:"BASIC_HTPASSWD" description:"htpasswd file for basic auth"` AuthBasicHtpasswd string `long:"basic-htpasswd" env:"BASIC_HTPASSWD" description:"htpasswd file for basic auth"`
RemoteLookupHeaders bool `long:"remote-lookup-headers" env:"REMOTE_LOOKUP_HEADERS" description:"enable remote lookup headers"`
LBType string `long:"lb-type" env:"LB_TYPE" description:"load balancer type" choice:"random" choice:"failover" default:"random"` // nolint LBType string `long:"lb-type" env:"LB_TYPE" description:"load balancer type" choice:"random" choice:"failover" default:"random"` // nolint
SSL struct { SSL struct {
Type string `long:"type" env:"TYPE" description:"ssl (auto) support" choice:"none" choice:"static" choice:"auto" default:"none"` // nolint Type string `long:"type" env:"TYPE" description:"ssl (auto) support" choice:"none" choice:"static" choice:"auto" default:"none"` // nolint
@ -273,10 +273,11 @@ func run() error {
ThrottleUser: opts.Throttle.User, ThrottleUser: opts.Throttle.User,
BasicAuthEnabled: len(basicAuthAllowed) > 0, BasicAuthEnabled: len(basicAuthAllowed) > 0,
BasicAuthAllowed: basicAuthAllowed, BasicAuthAllowed: basicAuthAllowed,
OnlyFrom: makeOnlyFromMiddleware(),
} }
err = px.Run(ctx) err = px.Run(ctx)
if err != nil && err == http.ErrServerClosed { if err != nil && errors.Is(err, http.ErrServerClosed) {
log.Printf("[WARN] proxy server closed, %v", err) // nolint gocritic log.Printf("[WARN] proxy server closed, %v", err) // nolint gocritic
return nil return nil
} }
@ -424,6 +425,13 @@ func makeLBSelector() func(len int) int {
} }
} }
func makeOnlyFromMiddleware() *proxy.OnlyFrom {
if opts.RemoteLookupHeaders {
return proxy.NewOnlyFrom(proxy.OFRealIP, proxy.OFForwarded, proxy.OFRemoteAddr)
}
return proxy.NewOnlyFrom(proxy.OFRemoteAddr)
}
func makeErrorReporter() (proxy.Reporter, error) { func makeErrorReporter() (proxy.Reporter, error) {
result := &proxy.ErrorReporter{ result := &proxy.ErrorReporter{
Nice: opts.ErrorReport.Enabled, Nice: opts.ErrorReport.Enabled,

View File

@ -244,5 +244,4 @@ func TestHttp_basicAuthHandler(t *testing.T) {
require.Equal(t, http.StatusOK, resp.StatusCode) require.Equal(t, http.StatusOK, resp.StatusCode)
}) })
} }
} }

149
app/proxy/only_from.go Normal file
View File

@ -0,0 +1,149 @@
package proxy
import (
"bytes"
"net"
"net/http"
"strings"
"github.com/umputun/reproxy/app/discovery"
)
// OnlyFrom implements middleware to allow access for a limited list of source IPs.
type OnlyFrom struct {
lookups []OFLookup
}
// OFLookup defines lookup method for source IP.
type OFLookup string
// enum of possible lookup methods
const (
OFRemoteAddr OFLookup = "remote-addr"
OFRealIP OFLookup = "real-ip"
OFForwarded OFLookup = "forwarded"
)
// NewOnlyFrom creates OnlyFrom middleware with given lookup methods.
func NewOnlyFrom(lookups ...OFLookup) *OnlyFrom {
return &OnlyFrom{lookups: lookups}
}
// Handler implements middleware interface.
func (o *OnlyFrom) Handler(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
var allowedIPs []string
reqCtx := r.Context()
if reqCtx.Value(ctxMatch) != nil { // route match detected by matchHandler
match := reqCtx.Value(ctxMatch).(discovery.MatchedRoute)
allowedIPs = match.Mapper.OnlyFromIPs
}
if len(allowedIPs) == 0 {
// no restrictions if no ips defined
next.ServeHTTP(w, r)
return
}
realIP := o.realIP(o.lookups, r)
if realIP != "" && o.matchRemoteIP(realIP, allowedIPs) {
next.ServeHTTP(w, r)
return
}
w.WriteHeader(http.StatusForbidden)
}
return http.HandlerFunc(fn)
}
func (o *OnlyFrom) realIP(ipLookups []OFLookup, r *http.Request) string {
realIP := r.Header.Get("X-Real-IP")
forwardedFor := r.Header.Get("X-Forwarded-For")
for _, lookup := range ipLookups {
if lookup == OFRemoteAddr {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr // can't parse, return as is
}
return ip
}
if lookup == OFForwarded && forwardedFor != "" {
// X-Forwarded-For is potentially a list of addresses separated with ","
// The left-most being the original client, and each successive proxy that passed the request
// adding the IP address where it received the request from.
// In case if the original IP is a private behind a proxy, we need to get the first public IP from the list
return preferPublicIP(strings.Split(forwardedFor, ","))
}
if lookup == OFRealIP && realIP != "" {
return realIP
}
}
return "" // we can't get real ip
}
// matchRemoteIP returns true if request's ip matches any of ips in the list of allowedIPs.
// allowedIPs can be defined as IP (like 192.168.1.12) or CIDR (192.168.0.0/16)
func (o *OnlyFrom) matchRemoteIP(remoteIP string, allowedIPs []string) bool {
for _, allowedIP := range allowedIPs {
// check for ip prefix or CIDR
if _, cidrnet, err := net.ParseCIDR(allowedIP); err == nil {
if cidrnet.Contains(net.ParseIP(remoteIP)) {
return true
}
}
// check for ip match
if remoteIP == allowedIP {
return true
}
}
return false
}
// preferPublicIP returns first public IP from the list of IPs
// if no public IP found, returns first IP from the list
func preferPublicIP(ips []string) string {
for _, ip := range ips {
ip = strings.TrimSpace(ip)
if net.ParseIP(ip).IsGlobalUnicast() && !isPrivateSubnet(net.ParseIP(ip)) {
return ip
}
}
return strings.TrimSpace(ips[0])
}
type ipRange struct {
start net.IP
end net.IP
}
var privateRanges = []ipRange{
{start: net.ParseIP("10.0.0.0"), end: net.ParseIP("10.255.255.255")},
{start: net.ParseIP("100.64.0.0"), end: net.ParseIP("100.127.255.255")},
{start: net.ParseIP("172.16.0.0"), end: net.ParseIP("172.31.255.255")},
{start: net.ParseIP("192.0.0.0"), end: net.ParseIP("192.0.0.255")},
{start: net.ParseIP("192.168.0.0"), end: net.ParseIP("192.168.255.255")},
{start: net.ParseIP("198.18.0.0"), end: net.ParseIP("198.19.255.255")},
{start: net.ParseIP("::1"), end: net.ParseIP("::1")},
{start: net.ParseIP("fc00::"), end: net.ParseIP("fdff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")},
{start: net.ParseIP("fe80::"), end: net.ParseIP("febf:ffff:ffff:ffff:ffff:ffff:ffff:ffff")},
}
// isPrivateSubnet - check to see if this ip is in a private subnet
func isPrivateSubnet(ipAddress net.IP) bool {
inRange := func(r ipRange, ipAddress net.IP) bool {
// ensure the IPs are in the same format for comparison
ipAddress = ipAddress.To16()
r.start = r.start.To16()
r.end = r.end.To16()
return bytes.Compare(ipAddress, r.start) >= 0 && bytes.Compare(ipAddress, r.end) <= 0
}
for _, r := range privateRanges {
if inRange(r, ipAddress) {
return true
}
}
return false
}

144
app/proxy/only_from_test.go Normal file
View File

@ -0,0 +1,144 @@
package proxy
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/umputun/reproxy/app/discovery"
)
func TestOnlyFrom_Handler(t *testing.T) {
tbl := []struct {
name string
lookups []OFLookup
allowedIPs []string
remoteAddr string
realIP string
forwardedFor string
expectedStatusCode int
}{
{
name: "allowed IP",
lookups: []OFLookup{OFRemoteAddr},
allowedIPs: []string{"192.168.1.1"},
remoteAddr: "192.168.1.1:1234",
expectedStatusCode: http.StatusOK,
},
{
name: "disallowed IP",
lookups: []OFLookup{OFRemoteAddr},
allowedIPs: []string{"192.168.1.1"},
remoteAddr: "192.168.1.2:1234",
expectedStatusCode: http.StatusForbidden,
},
{
name: "no restrictions",
lookups: []OFLookup{OFRemoteAddr},
allowedIPs: []string{},
remoteAddr: "192.168.1.2:1234",
expectedStatusCode: http.StatusOK,
},
{
name: "allowed IP with RealIP lookup",
lookups: []OFLookup{OFRealIP},
allowedIPs: []string{"192.168.1.1"},
realIP: "192.168.1.1",
expectedStatusCode: http.StatusOK,
},
{
name: "disallowed IP with RealIP lookup",
lookups: []OFLookup{OFRealIP},
allowedIPs: []string{"192.168.1.1"},
realIP: "192.168.1.2",
expectedStatusCode: http.StatusForbidden,
},
{
name: "allowed IP with Forwarded lookup",
lookups: []OFLookup{OFForwarded},
allowedIPs: []string{"192.168.1.1"},
forwardedFor: "192.168.1.1",
expectedStatusCode: http.StatusOK,
},
{
name: "allowed IP with Forwarded lookup, mix private and public IPs",
lookups: []OFLookup{OFForwarded},
allowedIPs: []string{"8.8.8.8"},
forwardedFor: "192.168.1.1, 10.0.0.5, 8.8.8.8, 10.10.10.10",
expectedStatusCode: http.StatusOK,
},
{
name: "disallowed IP with Forwarded lookup",
lookups: []OFLookup{OFForwarded},
allowedIPs: []string{"192.168.1.1"},
forwardedFor: "192.168.1.2",
expectedStatusCode: http.StatusForbidden,
},
{
name: "multiple lookups, allowed IP",
lookups: []OFLookup{OFRemoteAddr, OFRealIP},
allowedIPs: []string{"192.168.1.1", "192.168.1.2"},
remoteAddr: "192.168.1.2:1234",
realIP: "192.168.1.1",
expectedStatusCode: http.StatusOK,
},
{
name: "multiple lookups, disallowed IP",
lookups: []OFLookup{OFRemoteAddr, OFRealIP},
allowedIPs: []string{"192.168.1.1", "192.168.1.2"},
remoteAddr: "192.168.1.3:1234",
realIP: "192.168.1.3",
expectedStatusCode: http.StatusForbidden,
},
{
name: "CIDR block, allowed IP",
lookups: []OFLookup{OFRemoteAddr},
allowedIPs: []string{"192.168.1.0/24"},
remoteAddr: "192.168.1.2:1234",
expectedStatusCode: http.StatusOK,
},
{
name: "CIDR block, disallowed IP",
lookups: []OFLookup{OFRemoteAddr},
allowedIPs: []string{"192.168.1.0/24"},
remoteAddr: "192.168.2.2:1234",
expectedStatusCode: http.StatusForbidden,
},
{
name: "invalid remote address format",
lookups: []OFLookup{OFRemoteAddr},
allowedIPs: []string{"192.168.1.1"},
remoteAddr: "invalid_format",
expectedStatusCode: http.StatusForbidden,
},
{
name: "empty remote address",
lookups: []OFLookup{OFRemoteAddr},
allowedIPs: []string{"192.168.1.1"},
remoteAddr: "",
expectedStatusCode: http.StatusForbidden,
},
}
for _, tt := range tbl {
t.Run(tt.name, func(t *testing.T) {
onlyFrom := NewOnlyFrom(tt.lookups...)
handler := onlyFrom.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
req := httptest.NewRequest("GET", "http://example.com/foo", http.NoBody)
req.RemoteAddr = tt.remoteAddr
req.Header.Set("X-Real-IP", tt.realIP)
req.Header.Set("X-Forwarded-For", tt.forwardedFor)
req = req.WithContext(context.WithValue(req.Context(),
ctxMatch, discovery.MatchedRoute{Mapper: discovery.URLMapper{OnlyFromIPs: tt.allowedIPs}}))
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, tt.expectedStatusCode, rr.Code)
})
}
}

View File

@ -28,27 +28,27 @@ import (
// Http is a proxy server for both http and https // Http is a proxy server for both http and https
type Http struct { // nolint golint type Http struct { // nolint golint
Matcher Matcher
Address string Address string
AssetsLocation string AssetsLocation string
AssetsWebRoot string AssetsWebRoot string
Assets404 string Assets404 string
AssetsSPA bool AssetsSPA bool
MaxBodySize int64 MaxBodySize int64
GzEnabled bool GzEnabled bool
ProxyHeaders []string ProxyHeaders []string
DropHeader []string DropHeader []string
SSLConfig SSLConfig SSLConfig SSLConfig
Version string Version string
AccessLog io.Writer AccessLog io.Writer
StdOutEnabled bool StdOutEnabled bool
Signature bool Signature bool
Timeouts Timeouts Timeouts Timeouts
CacheControl MiddlewareProvider CacheControl MiddlewareProvider
Metrics MiddlewareProvider Metrics MiddlewareProvider
PluginConductor MiddlewareProvider PluginConductor MiddlewareProvider
Reporter Reporter Reporter Reporter
LBSelector func(len int) int LBSelector func(len int) int
OnlyFrom *OnlyFrom
BasicAuthEnabled bool BasicAuthEnabled bool
BasicAuthAllowed []string BasicAuthAllowed []string
@ -121,18 +121,19 @@ func (h *Http) Run(ctx context.Context) error {
}() }()
handler := R.Wrap(h.proxyHandler(), handler := R.Wrap(h.proxyHandler(),
R.Recoverer(log.Default()), // recover on errors R.Recoverer(log.Default()), // recover on errors
signatureHandler(h.Signature, h.Version), // send app signature signatureHandler(h.Signature, h.Version), // send app signature
h.pingHandler, // respond to /ping h.OnlyFrom.Handler, // limit source (remote) IPs if defined
h.pingHandler, // respond to /ping
basicAuthHandler(h.BasicAuthEnabled, h.BasicAuthAllowed), // basic auth basicAuthHandler(h.BasicAuthEnabled, h.BasicAuthAllowed), // basic auth
h.healthMiddleware, // respond to /health h.healthMiddleware, // respond to /health
h.matchHandler, // set matched routes to context h.matchHandler, // set matched routes to context
limiterSystemHandler(h.ThrottleSystem), // limit total requests/sec limiterSystemHandler(h.ThrottleSystem), // limit total requests/sec
limiterUserHandler(h.ThrottleUser), // req/seq per user/route match limiterUserHandler(h.ThrottleUser), // req/seq per user/route match
h.mgmtHandler(), // handles /metrics and /routes for prometheus h.mgmtHandler(), // handles /metrics and /routes for prometheus
h.pluginHandler(), // prc to external plugins h.pluginHandler(), // prc to external plugins
headersHandler(h.ProxyHeaders, h.DropHeader), // add response headers and delete some request headers headersHandler(h.ProxyHeaders, h.DropHeader), // add response headers and delete some request headers
accessLogHandler(h.AccessLog), // apache-format log file accessLogHandler(h.AccessLog), // apache-format log file
stdoutLogHandler(h.StdOutEnabled, logger.New(logger.Log(log.Default()), logger.Prefix("[INFO]")).Handler), stdoutLogHandler(h.StdOutEnabled, logger.New(logger.Log(log.Default()), logger.Prefix("[INFO]")).Handler),
maxReqSizeHandler(h.MaxBodySize), // limit request max size maxReqSizeHandler(h.MaxBodySize), // limit request max size
gzipHandler(h.GzEnabled), // gzip response gzipHandler(h.GzEnabled), // gzip response
@ -400,22 +401,22 @@ func (h *Http) makeHTTPServer(addr string, router http.Handler) *http.Server {
} }
func (h *Http) setXRealIP(r *http.Request) { func (h *Http) setXRealIP(r *http.Request) {
if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
remoteIP := r.Header.Get("X-Forwarded-For") // use the left-most non-private client IP address
if remoteIP == "" { // if there is no any non-private IP address, use the left-most address
remoteIP = r.RemoteAddr r.Header.Set("X-Real-IP", preferPublicIP(strings.Split(forwarded, ",")))
}
ip, _, err := net.SplitHostPort(remoteIP)
if err != nil {
return return
} }
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return
}
userIP := net.ParseIP(ip) userIP := net.ParseIP(ip)
if userIP == nil { if userIP == nil {
return return
} }
r.Header.Add("X-Real-IP", ip) r.Header.Set("X-Real-IP", ip)
} }
// discoveredServers gets the list of servers discovered by providers. // discoveredServers gets the list of servers discovered by providers.

View File

@ -34,6 +34,7 @@ func TestHttp_Do(t *testing.T) {
t.Logf("req: %v", r) t.Logf("req: %v", r)
w.Header().Add("h1", "v1") w.Header().Add("h1", "v1")
require.Equal(t, "127.0.0.1", r.Header.Get("X-Real-IP")) require.Equal(t, "127.0.0.1", r.Header.Get("X-Real-IP"))
require.Equal(t, "127.0.0.1", r.Header.Get("X-Forwarded-For"))
fmt.Fprintf(w, "response %s", r.URL.String()) fmt.Fprintf(w, "response %s", r.URL.String())
})) }))
@ -59,7 +60,7 @@ func TestHttp_Do(t *testing.T) {
client := http.Client{} client := http.Client{}
{ t.Run("to 127.0.0.1, good", func(t *testing.T) {
req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody) req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody)
require.NoError(t, err) require.NoError(t, err)
resp, err := client.Do(req) resp, err := client.Do(req)
@ -75,9 +76,9 @@ func TestHttp_Do(t *testing.T) {
assert.Equal(t, "v1", resp.Header.Get("h1")) assert.Equal(t, "v1", resp.Header.Get("h1"))
assert.Equal(t, "vv1", resp.Header.Get("hh1")) assert.Equal(t, "vv1", resp.Header.Get("hh1"))
assert.Equal(t, "vv2", resp.Header.Get("hh2")) assert.Equal(t, "vv2", resp.Header.Get("hh2"))
} })
{ t.Run("to localhost, good", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/api/something") resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/api/something")
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
@ -89,9 +90,9 @@ func TestHttp_Do(t *testing.T) {
assert.Equal(t, "response /123/something", string(body)) assert.Equal(t, "response /123/something", string(body))
assert.Equal(t, "reproxy", resp.Header.Get("App-Name")) assert.Equal(t, "reproxy", resp.Header.Get("App-Name"))
assert.Equal(t, "v1", resp.Header.Get("h1")) assert.Equal(t, "v1", resp.Header.Get("h1"))
} })
{ t.Run("bad gateway", func(t *testing.T) {
resp, err := client.Get("http://127.0.0.1:" + strconv.Itoa(port) + "/bad/something") resp, err := client.Get("http://127.0.0.1:" + strconv.Itoa(port) + "/bad/something")
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
@ -100,9 +101,9 @@ func TestHttp_Do(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Contains(t, string(b), "Sorry for the inconvenience") assert.Contains(t, string(b), "Sorry for the inconvenience")
assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type")) assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type"))
} })
{ t.Run("url encode", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/api/test%20%25%20and%20&,%20and%20other%20characters%20@%28%29%5E%21") resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/api/test%20%25%20and%20&,%20and%20other%20characters%20@%28%29%5E%21")
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
@ -114,7 +115,7 @@ func TestHttp_Do(t *testing.T) {
assert.Equal(t, "response /123/test%20%25%20and%20&,%20and%20other%20characters%20@%28%29%5E%21", string(body)) assert.Equal(t, "response /123/test%20%25%20and%20&,%20and%20other%20characters%20@%28%29%5E%21", string(body))
assert.Equal(t, "reproxy", resp.Header.Get("App-Name")) assert.Equal(t, "reproxy", resp.Header.Get("App-Name"))
assert.Equal(t, "v1", resp.Header.Get("h1")) assert.Equal(t, "v1", resp.Header.Get("h1"))
} })
} }
func TestHttp_DoWithAssets(t *testing.T) { func TestHttp_DoWithAssets(t *testing.T) {
@ -153,7 +154,7 @@ func TestHttp_DoWithAssets(t *testing.T) {
client := http.Client{} client := http.Client{}
{ t.Run("api call", func(t *testing.T) {
req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody) req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody)
require.NoError(t, err) require.NoError(t, err)
resp, err := client.Do(req) resp, err := client.Do(req)
@ -167,9 +168,9 @@ func TestHttp_DoWithAssets(t *testing.T) {
assert.Equal(t, "response /567/something", string(body)) assert.Equal(t, "response /567/something", string(body))
assert.Equal(t, "", resp.Header.Get("App-Method")) assert.Equal(t, "", resp.Header.Get("App-Method"))
assert.Equal(t, "v1", resp.Header.Get("h1")) assert.Equal(t, "v1", resp.Header.Get("h1"))
} })
{ t.Run("static call, good", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/1.html") resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/1.html")
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
@ -182,9 +183,9 @@ func TestHttp_DoWithAssets(t *testing.T) {
assert.Equal(t, "", resp.Header.Get("App-Method")) assert.Equal(t, "", resp.Header.Get("App-Method"))
assert.Equal(t, "", resp.Header.Get("h1")) assert.Equal(t, "", resp.Header.Get("h1"))
assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control")) assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control"))
} })
{ t.Run("static call, bad", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad.html") resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad.html")
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
@ -192,9 +193,9 @@ func TestHttp_DoWithAssets(t *testing.T) {
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "404 page not found\n", string(body)) assert.Equal(t, "404 page not found\n", string(body))
} })
{ t.Run("bad url", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/svcbad") resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/svcbad")
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
@ -203,7 +204,7 @@ func TestHttp_DoWithAssets(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Contains(t, string(body), "Server error") assert.Contains(t, string(body), "Server error")
assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type")) assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type"))
} })
} }
func TestHttp_DoWithAssetsCustom404(t *testing.T) { func TestHttp_DoWithAssetsCustom404(t *testing.T) {
@ -243,7 +244,7 @@ func TestHttp_DoWithAssetsCustom404(t *testing.T) {
client := http.Client{} client := http.Client{}
{ t.Run("api call, found", func(t *testing.T) {
req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody) req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody)
require.NoError(t, err) require.NoError(t, err)
resp, err := client.Do(req) resp, err := client.Do(req)
@ -257,9 +258,9 @@ func TestHttp_DoWithAssetsCustom404(t *testing.T) {
assert.Equal(t, "response /567/something", string(body)) assert.Equal(t, "response /567/something", string(body))
assert.Equal(t, "", resp.Header.Get("App-Method")) assert.Equal(t, "", resp.Header.Get("App-Method"))
assert.Equal(t, "v1", resp.Header.Get("h1")) assert.Equal(t, "v1", resp.Header.Get("h1"))
} })
{ t.Run("static call, found", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/1.html") resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/1.html")
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
@ -272,9 +273,9 @@ func TestHttp_DoWithAssetsCustom404(t *testing.T) {
assert.Equal(t, "", resp.Header.Get("App-Method")) assert.Equal(t, "", resp.Header.Get("App-Method"))
assert.Equal(t, "", resp.Header.Get("h1")) assert.Equal(t, "", resp.Header.Get("h1"))
assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control")) assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control"))
} })
{ t.Run("static call, not found", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad.html") resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad.html")
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
@ -284,9 +285,9 @@ func TestHttp_DoWithAssetsCustom404(t *testing.T) {
assert.Equal(t, "not found! blah blah blah\nthere is no spoon", string(body)) assert.Equal(t, "not found! blah blah blah\nthere is no spoon", string(body))
t.Logf("%+v", resp.Header) t.Logf("%+v", resp.Header)
assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type")) assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type"))
} })
{ t.Run("another static call, not found", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad2.html") resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad2.html")
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
@ -296,7 +297,7 @@ func TestHttp_DoWithAssetsCustom404(t *testing.T) {
assert.Equal(t, "not found! blah blah blah\nthere is no spoon", string(body)) assert.Equal(t, "not found! blah blah blah\nthere is no spoon", string(body))
t.Logf("%+v", resp.Header) t.Logf("%+v", resp.Header)
assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type")) assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type"))
} })
} }
func TestHttp_DoWithSpaAssets(t *testing.T) { func TestHttp_DoWithSpaAssets(t *testing.T) {
@ -336,7 +337,7 @@ func TestHttp_DoWithSpaAssets(t *testing.T) {
client := http.Client{} client := http.Client{}
{ t.Run("api call, good", func(t *testing.T) {
req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody) req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody)
require.NoError(t, err) require.NoError(t, err)
resp, err := client.Do(req) resp, err := client.Do(req)
@ -350,9 +351,9 @@ func TestHttp_DoWithSpaAssets(t *testing.T) {
assert.Equal(t, "response /567/something", string(body)) assert.Equal(t, "response /567/something", string(body))
assert.Equal(t, "", resp.Header.Get("App-Method")) assert.Equal(t, "", resp.Header.Get("App-Method"))
assert.Equal(t, "v1", resp.Header.Get("h1")) assert.Equal(t, "v1", resp.Header.Get("h1"))
} })
{ t.Run("static call, good", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/1.html") resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/1.html")
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
@ -365,9 +366,9 @@ func TestHttp_DoWithSpaAssets(t *testing.T) {
assert.Equal(t, "", resp.Header.Get("App-Method")) assert.Equal(t, "", resp.Header.Get("App-Method"))
assert.Equal(t, "", resp.Header.Get("h1")) assert.Equal(t, "", resp.Header.Get("h1"))
assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control")) assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control"))
} })
{ t.Run("static call, not found server index", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad.html") resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad.html")
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
@ -380,9 +381,9 @@ func TestHttp_DoWithSpaAssets(t *testing.T) {
assert.Equal(t, "", resp.Header.Get("App-Method")) assert.Equal(t, "", resp.Header.Get("App-Method"))
assert.Equal(t, "", resp.Header.Get("h1")) assert.Equal(t, "", resp.Header.Get("h1"))
assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control")) assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control"))
} })
{ t.Run("static call, bad url", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/svcbad") resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/svcbad")
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
@ -391,7 +392,7 @@ func TestHttp_DoWithSpaAssets(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Contains(t, string(body), "Server error") assert.Contains(t, string(body), "Server error")
assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type")) assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type"))
} })
} }
func TestHttp_DoWithAssetRules(t *testing.T) { func TestHttp_DoWithAssetRules(t *testing.T) {
@ -715,16 +716,16 @@ func TestHttp_withBasicAuth(t *testing.T) {
client := http.Client{} client := http.Client{}
{ t.Run("no auth", func(t *testing.T) {
req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg")) req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
require.NoError(t, err) require.NoError(t, err)
resp, err := client.Do(req) resp, err := client.Do(req)
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
} })
{ t.Run("bad auth", func(t *testing.T) {
req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg")) req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
req.SetBasicAuth("test", "badpasswd") req.SetBasicAuth("test", "badpasswd")
require.NoError(t, err) require.NoError(t, err)
@ -732,8 +733,9 @@ func TestHttp_withBasicAuth(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
} })
{
t.Run("good auth", func(t *testing.T) {
req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg")) req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
req.SetBasicAuth("test", "passwd") req.SetBasicAuth("test", "passwd")
require.NoError(t, err) require.NoError(t, err)
@ -741,8 +743,9 @@ func TestHttp_withBasicAuth(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, http.StatusOK, resp.StatusCode)
} })
{
t.Run("good auth 2", func(t *testing.T) {
req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg")) req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
req.SetBasicAuth("test2", "passwd2") req.SetBasicAuth("test2", "passwd2")
require.NoError(t, err) require.NoError(t, err)
@ -750,7 +753,7 @@ func TestHttp_withBasicAuth(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, http.StatusOK, resp.StatusCode)
} })
} }
func TestHttp_toHttp(t *testing.T) { func TestHttp_toHttp(t *testing.T) {
@ -766,9 +769,9 @@ func TestHttp_toHttp(t *testing.T) {
} }
h := Http{} h := Http{}
for i, tt := range tbl { for _, tt := range tbl {
tt := tt tt := tt
t.Run(strconv.Itoa(i), func(t *testing.T) { t.Run(tt.addr, func(t *testing.T) {
assert.Equal(t, tt.res, h.toHTTP(tt.addr, tt.port)) assert.Equal(t, tt.res, h.toHTTP(tt.addr, tt.port))
}) })
} }
@ -791,8 +794,8 @@ func TestHttp_isAssetRequest(t *testing.T) {
{"/static/", "/tmp", "", false}, {"/static/", "/tmp", "", false},
} }
for i, tt := range tbl { for _, tt := range tbl {
t.Run(strconv.Itoa(i), func(t *testing.T) { t.Run(tt.req, func(t *testing.T) {
h := Http{AssetsLocation: tt.assetsLocation, AssetsWebRoot: tt.assetsWebRoot} h := Http{AssetsLocation: tt.assetsLocation, AssetsWebRoot: tt.assetsWebRoot}
r, err := http.NewRequest("GET", tt.req, http.NoBody) r, err := http.NewRequest("GET", tt.req, http.NoBody)
require.NoError(t, err) require.NoError(t, err)
@ -803,56 +806,61 @@ func TestHttp_isAssetRequest(t *testing.T) {
} }
func TestHttp_matchHandler(t *testing.T) { func TestHttp_matchHandler(t *testing.T) {
tbl := []struct { tbl := []struct {
name string
matches discovery.Matches matches discovery.Matches
res string res string
ok bool ok bool
}{ }{
{ {
discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{ name: "all alive destinations",
matches: discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
{Destination: "dest1", Alive: true}, {Destination: "dest1", Alive: true},
{Destination: "dest2", Alive: true}, {Destination: "dest2", Alive: true},
{Destination: "dest3", Alive: true}, {Destination: "dest3", Alive: true},
}}, }},
"dest1", true, res: "dest1", ok: true,
}, },
{ {
discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{ name: "second alive destination",
matches: discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
{Destination: "dest1", Alive: false}, {Destination: "dest1", Alive: false},
{Destination: "dest2", Alive: true}, {Destination: "dest2", Alive: true},
{Destination: "dest3", Alive: false}, {Destination: "dest3", Alive: false},
}}, }},
"dest2", true, res: "dest2", ok: true,
}, },
{ {
discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{ name: "one dead destination",
matches: discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
{Destination: "dest1", Alive: false}, {Destination: "dest1", Alive: false},
{Destination: "dest2", Alive: true}, {Destination: "dest2", Alive: true},
{Destination: "dest3", Alive: true}, {Destination: "dest3", Alive: true},
}}, }},
"dest2", true, res: "dest2", ok: true,
}, },
{ {
discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{ name: "last alive destination",
matches: discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
{Destination: "dest1", Alive: false}, {Destination: "dest1", Alive: false},
{Destination: "dest2", Alive: false}, {Destination: "dest2", Alive: false},
{Destination: "dest3", Alive: true}, {Destination: "dest3", Alive: true},
}}, }},
"dest3", true, res: "dest3", ok: true,
}, },
{ {
discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{ name: "all dead destinations",
matches: discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
{Destination: "dest1", Alive: false}, {Destination: "dest1", Alive: false},
{Destination: "dest2", Alive: false}, {Destination: "dest2", Alive: false},
{Destination: "dest3", Alive: false}, {Destination: "dest3", Alive: false},
}}, }},
"", false, res: "", ok: false,
}, },
{ {
discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{}}, "", false, name: "no destinations",
matches: discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{}}, res: "", ok: false,
}, },
} }
@ -864,9 +872,8 @@ func TestHttp_matchHandler(t *testing.T) {
} }
client := http.Client{} client := http.Client{}
for i, tt := range tbl { for _, tt := range tbl {
t.Run(strconv.Itoa(i), func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := Http{Matcher: matcherMock, LBSelector: func(len int) int { return 0 }} h := Http{Matcher: matcherMock, LBSelector: func(len int) int { return 0 }}
handler := h.matchHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := h.matchHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("req: %+v", r) t.Logf("req: %+v", r)
@ -893,7 +900,6 @@ func TestHttp_matchHandler(t *testing.T) {
} }
func TestHttp_discoveredServers(t *testing.T) { func TestHttp_discoveredServers(t *testing.T) {
calls := 0 calls := 0
m := &MatcherMock{ServersFunc: func() []string { m := &MatcherMock{ServersFunc: func() []string {
defer func() { calls++ }() defer func() { calls++ }()