1
0
mirror of https://github.com/umputun/reproxy.git synced 2025-06-30 22:13:42 +02:00

add middleware to optionally allow requests from giving ips/ranges

add new remote param to docker and file providers

lint: http nil body

add support of remote ips to consul provider

local implementation of onlyfrom middleware

lint: missing comment

make proxy tests more readable

preffer public IP if any forwwarded
This commit is contained in:
Umputun
2023-11-25 13:54:14 -06:00
parent 36a5443378
commit a896f08eec
16 changed files with 531 additions and 142 deletions

View File

@ -37,6 +37,7 @@ type URLMapper struct {
PingURL string PingURL string
MatchType MatchType MatchType MatchType
RedirectType RedirectType RedirectType RedirectType
OnlyFromIPs []string
AssetsLocation string // local FS root location AssetsLocation string // local FS root location
AssetsWebRoot string // web root location AssetsWebRoot string // web root location
@ -484,16 +485,6 @@ func (s *Service) mergeEvents(ctx context.Context, chs ...<-chan ProviderID) <-c
return out return out
} }
// Contains checks if the input string (e) in the given slice
func Contains(e string, s []string) bool {
for _, a := range s {
if a == e {
return true
}
}
return false
}
// IsAlive indicates whether mapper destination is alive // IsAlive indicates whether mapper destination is alive
func (m URLMapper) IsAlive() bool { func (m URLMapper) IsAlive() bool {
return !m.dead return !m.dead
@ -515,3 +506,24 @@ func (m URLMapper) ping() (string, error) {
return "", err return "", err
} }
// Contains checks if the input string (e) in the given slice
func Contains(e string, s []string) bool {
for _, a := range s {
if a == e {
return true
}
}
return false
}
// ParseOnlyFrom parses comma separated list of IPs
func ParseOnlyFrom(s string) (res []string) {
if s == "" {
return []string{}
}
for _, v := range strings.Split(s, ",") {
res = append(res, strings.TrimSpace(v))
}
return res
}

View File

@ -39,7 +39,7 @@ func TestService_Run(t *testing.T) {
ListFunc: func() ([]URLMapper, error) { ListFunc: func() ([]URLMapper, error) {
return []URLMapper{ return []URLMapper{
{Server: "localhost", SrcMatch: *regexp.MustCompile("/api/svc3/xyz"), {Server: "localhost", SrcMatch: *regexp.MustCompile("/api/svc3/xyz"),
Dst: "http://127.0.0.3:8080/blah3/xyz", ProviderID: PIDocker}, Dst: "http://127.0.0.3:8080/blah3/xyz", ProviderID: PIDocker, OnlyFromIPs: []string{"127.0.0.1"}},
}, nil }, nil
}, },
} }
@ -66,6 +66,7 @@ func TestService_Run(t *testing.T) {
assert.Equal(t, "localhost", mappers[0].Server) assert.Equal(t, "localhost", mappers[0].Server)
assert.Equal(t, "/api/svc3/xyz", mappers[0].SrcMatch.String()) assert.Equal(t, "/api/svc3/xyz", mappers[0].SrcMatch.String())
assert.Equal(t, "http://127.0.0.3:8080/blah3/xyz", mappers[0].Dst) assert.Equal(t, "http://127.0.0.3:8080/blah3/xyz", mappers[0].Dst)
assert.Equal(t, []string{"127.0.0.1"}, mappers[0].OnlyFromIPs)
assert.Equal(t, 1, len(p1.EventsCalls())) assert.Equal(t, 1, len(p1.EventsCalls()))
assert.Equal(t, 1, len(p2.EventsCalls())) assert.Equal(t, 1, len(p2.EventsCalls()))
@ -104,7 +105,8 @@ func TestService_Match(t *testing.T) {
}, },
ListFunc: func() ([]URLMapper, error) { ListFunc: func() ([]URLMapper, error) {
return []URLMapper{ return []URLMapper{
{SrcMatch: *regexp.MustCompile("/api/svc3/xyz"), Dst: "http://127.0.0.3:8080/blah3/xyz", ProviderID: PIDocker}, {SrcMatch: *regexp.MustCompile("/api/svc3/xyz"), Dst: "http://127.0.0.3:8080/blah3/xyz",
OnlyFromIPs: []string{"127.0.0.1", "192.168.1.0/24"}, ProviderID: PIDocker},
{SrcMatch: *regexp.MustCompile("/web"), Dst: "/var/web", ProviderID: PIDocker, MatchType: MTStatic, {SrcMatch: *regexp.MustCompile("/web"), Dst: "/var/web", ProviderID: PIDocker, MatchType: MTStatic,
AssetsWebRoot: "/web", AssetsLocation: "/var/web"}, AssetsWebRoot: "/web", AssetsLocation: "/var/web"},
{SrcMatch: *regexp.MustCompile("/www/"), Dst: "/var/web", ProviderID: PIDocker, MatchType: MTStatic, {SrcMatch: *regexp.MustCompile("/www/"), Dst: "/var/web", ProviderID: PIDocker, MatchType: MTStatic,
@ -131,9 +133,11 @@ func TestService_Match(t *testing.T) {
res Matches res Matches
}{ }{
{"example.com", "/api/svc3/xyz/something", Matches{MTProxy, []MatchedRoute{ {"example.com", "/api/svc3/xyz/something", Matches{MTProxy, []MatchedRoute{
{Destination: "http://127.0.0.3:8080/blah3/xyz/something", Alive: true}}}}, {Destination: "http://127.0.0.3:8080/blah3/xyz/something", Alive: true,
Mapper: URLMapper{OnlyFromIPs: []string{"127.0.0.1", "192.168.1.0/24"}}}}}},
{"example.com", "/api/svc3/xyz", Matches{MTProxy, []MatchedRoute{{ {"example.com", "/api/svc3/xyz", Matches{MTProxy, []MatchedRoute{{
Destination: "http://127.0.0.3:8080/blah3/xyz", Alive: true}}}}, Destination: "http://127.0.0.3:8080/blah3/xyz", Alive: true,
Mapper: URLMapper{OnlyFromIPs: []string{"127.0.0.1", "192.168.1.0/24"}}}}}},
{"abc.example.com", "/api/svc1/1234", Matches{MTProxy, []MatchedRoute{ {"abc.example.com", "/api/svc1/1234", Matches{MTProxy, []MatchedRoute{
{Destination: "http://127.0.0.1:8080/blah1/1234", Alive: true}}}}, {Destination: "http://127.0.0.1:8080/blah1/1234", Alive: true}}}},
{"zzz.example.com", "/aaa/api/svc1/1234", Matches{MTProxy, nil}}, {"zzz.example.com", "/aaa/api/svc1/1234", Matches{MTProxy, nil}},
@ -167,6 +171,7 @@ func TestService_Match(t *testing.T) {
for i := 0; i < len(res.Routes); i++ { for i := 0; i < len(res.Routes); i++ {
assert.Equal(t, tt.res.Routes[i].Alive, res.Routes[i].Alive) assert.Equal(t, tt.res.Routes[i].Alive, res.Routes[i].Alive)
assert.Equal(t, tt.res.Routes[i].Destination, res.Routes[i].Destination) assert.Equal(t, tt.res.Routes[i].Destination, res.Routes[i].Destination)
assert.Equal(t, tt.res.Routes[i].Mapper.OnlyFromIPs, res.Routes[i].Mapper.OnlyFromIPs)
} }
assert.Equal(t, tt.res.MatchType, res.MatchType) assert.Equal(t, tt.res.MatchType, res.MatchType)
}) })
@ -608,3 +613,39 @@ func TestCheckHealth(t *testing.T) {
assert.NoError(t, res[ts.URL]) assert.NoError(t, res[ts.URL])
assert.NoError(t, res[ts2.URL]) assert.NoError(t, res[ts2.URL])
} }
func TestParseOnlyFrom(t *testing.T) {
tbl := []struct {
name string
input string
expected []string
}{
{
name: "empty string",
input: "",
expected: []string{},
},
{
name: "single IP",
input: "192.168.1.1",
expected: []string{"192.168.1.1"},
},
{
name: "multiple IPs",
input: "192.168.1.1, 192.168.1.2, 192.168.1.3, 10.0.0.0/16",
expected: []string{"192.168.1.1", "192.168.1.2", "192.168.1.3", "10.0.0.0/16"},
},
{
name: "multiple IPs with extra spaces",
input: " 192.168.1.1 , 192.168.1.2 , 192.168.1.3 ",
expected: []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"},
},
}
for _, tt := range tbl {
t.Run(tt.name, func(t *testing.T) {
result := ParseOnlyFrom(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}

View File

@ -3,12 +3,13 @@ package consulcatalog
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/umputun/reproxy/app/discovery"
"log" "log"
"regexp" "regexp"
"sort" "sort"
"strings" "strings"
"time" "time"
"github.com/umputun/reproxy/app/discovery"
) )
//go:generate moq -out consul_client_mock.go -skip-ensure -fmt goimports . ConsulClient //go:generate moq -out consul_client_mock.go -skip-ensure -fmt goimports . ConsulClient
@ -139,6 +140,7 @@ func (cc *ConsulCatalog) List() ([]discovery.URLMapper, error) {
destURL := fmt.Sprintf("http://%s:%d/$1", c.ServiceAddress, c.ServicePort) destURL := fmt.Sprintf("http://%s:%d/$1", c.ServiceAddress, c.ServicePort)
pingURL := fmt.Sprintf("http://%s:%d/ping", c.ServiceAddress, c.ServicePort) pingURL := fmt.Sprintf("http://%s:%d/ping", c.ServiceAddress, c.ServicePort)
server := "*" server := "*"
onlyFrom := []string{}
if v, ok := c.Labels["reproxy.enabled"]; ok && (v == "true" || v == "yes" || v == "1") { if v, ok := c.Labels["reproxy.enabled"]; ok && (v == "true" || v == "yes" || v == "1") {
enabled = true enabled = true
@ -159,6 +161,10 @@ func (cc *ConsulCatalog) List() ([]discovery.URLMapper, error) {
server = v server = v
} }
if v, ok := c.Labels["reproxy.remote"]; ok {
onlyFrom = discovery.ParseOnlyFrom(v)
}
if v, ok := c.Labels["reproxy.ping"]; ok { if v, ok := c.Labels["reproxy.ping"]; ok {
enabled = true enabled = true
pingURL = fmt.Sprintf("http://%s:%d%s", c.ServiceAddress, c.ServicePort, v) pingURL = fmt.Sprintf("http://%s:%d%s", c.ServiceAddress, c.ServicePort, v)
@ -177,7 +183,7 @@ func (cc *ConsulCatalog) List() ([]discovery.URLMapper, error) {
// server label may have multiple, comma separated servers // server label may have multiple, comma separated servers
for _, srv := range strings.Split(server, ",") { for _, srv := range strings.Split(server, ",") {
res = append(res, discovery.URLMapper{Server: strings.TrimSpace(srv), SrcMatch: *srcRegex, Dst: destURL, res = append(res, discovery.URLMapper{Server: strings.TrimSpace(srv), SrcMatch: *srcRegex, Dst: destURL,
PingURL: pingURL, ProviderID: discovery.PIConsulCatalog}) OnlyFromIPs: onlyFrom, PingURL: pingURL, ProviderID: discovery.PIConsulCatalog})
} }
} }

View File

@ -3,12 +3,14 @@ package consulcatalog
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/umputun/reproxy/app/discovery"
"sort" "sort"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/umputun/reproxy/app/discovery"
) )
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
@ -62,7 +64,8 @@ func TestConsulCatalog_List(t *testing.T) {
ServiceAddress: "addr3", ServiceAddress: "addr3",
ServicePort: 3000, ServicePort: 3000,
Labels: map[string]string{"reproxy.route": "^/api/123/(.*)", "reproxy.dest": "/blah/$1", Labels: map[string]string{"reproxy.route": "^/api/123/(.*)", "reproxy.dest": "/blah/$1",
"reproxy.server": "example.com,domain.com", "reproxy.ping": "/ping", "reproxy.enabled": "yes"}, "reproxy.server": "example.com,domain.com", "reproxy.ping": "/ping",
"reproxy.enabled": "yes", "reproxy.remote": "127.0.0.1, 192.168.1.0/24"},
}, },
{ {
ServiceID: "id4", ServiceID: "id4",
@ -91,21 +94,25 @@ func TestConsulCatalog_List(t *testing.T) {
assert.Equal(t, "http://addr3:3000/blah/$1", res[0].Dst) assert.Equal(t, "http://addr3:3000/blah/$1", res[0].Dst)
assert.Equal(t, "example.com", res[0].Server) assert.Equal(t, "example.com", res[0].Server)
assert.Equal(t, "http://addr3:3000/ping", res[0].PingURL) assert.Equal(t, "http://addr3:3000/ping", res[0].PingURL)
assert.Equal(t, []string{"127.0.0.1", "192.168.1.0/24"}, res[0].OnlyFromIPs)
assert.Equal(t, "^/api/123/(.*)", res[1].SrcMatch.String()) assert.Equal(t, "^/api/123/(.*)", res[1].SrcMatch.String())
assert.Equal(t, "http://addr3:3000/blah/$1", res[1].Dst) assert.Equal(t, "http://addr3:3000/blah/$1", res[1].Dst)
assert.Equal(t, "domain.com", res[1].Server) assert.Equal(t, "domain.com", res[1].Server)
assert.Equal(t, "http://addr3:3000/ping", res[1].PingURL) assert.Equal(t, "http://addr3:3000/ping", res[1].PingURL)
assert.Equal(t, []string{"127.0.0.1", "192.168.1.0/24"}, res[1].OnlyFromIPs)
assert.Equal(t, "^/(.*)", res[2].SrcMatch.String()) assert.Equal(t, "^/(.*)", res[2].SrcMatch.String())
assert.Equal(t, "http://addr44:4000/$1", res[2].Dst) assert.Equal(t, "http://addr44:4000/$1", res[2].Dst)
assert.Equal(t, "http://addr44:4000/ping", res[2].PingURL) assert.Equal(t, "http://addr44:4000/ping", res[2].PingURL)
assert.Equal(t, "*", res[2].Server) assert.Equal(t, "*", res[2].Server)
assert.Equal(t, []string{}, res[2].OnlyFromIPs)
assert.Equal(t, "^/(.*)", res[3].SrcMatch.String()) assert.Equal(t, "^/(.*)", res[3].SrcMatch.String())
assert.Equal(t, "http://addr2:2000/$1", res[3].Dst) assert.Equal(t, "http://addr2:2000/$1", res[3].Dst)
assert.Equal(t, "http://addr2:2000/ping", res[3].PingURL) assert.Equal(t, "http://addr2:2000/ping", res[3].PingURL)
assert.Equal(t, "*", res[3].Server) assert.Equal(t, "*", res[3].Server)
assert.Equal(t, []string{}, res[3].OnlyFromIPs)
} }
func TestConsulCatalog_serviceListWasChanged(t *testing.T) { func TestConsulCatalog_serviceListWasChanged(t *testing.T) {

View File

@ -103,6 +103,7 @@ func (d *Docker) parseContainerInfo(c containerInfo) (res []discovery.URLMapper)
// defaults // defaults
destURL, pingURL, server := fmt.Sprintf("http://%s:%d/$1", c.IP, port), fmt.Sprintf("http://%s:%d/ping", c.IP, port), "*" destURL, pingURL, server := fmt.Sprintf("http://%s:%d/$1", c.IP, port), fmt.Sprintf("http://%s:%d/ping", c.IP, port), "*"
assetsWebRoot, assetsLocation, assetsSPA := "", "", false assetsWebRoot, assetsLocation, assetsSPA := "", "", false
onlyFrom := []string{}
if d.AutoAPI && n == 0 { if d.AutoAPI && n == 0 {
enabled = true enabled = true
@ -133,6 +134,10 @@ func (d *Docker) parseContainerInfo(c containerInfo) (res []discovery.URLMapper)
server = v server = v
} }
if v, ok := d.labelN(c.Labels, n, "remote"); ok {
onlyFrom = discovery.ParseOnlyFrom(v)
}
if v, ok := d.labelN(c.Labels, n, "ping"); ok { if v, ok := d.labelN(c.Labels, n, "ping"); ok {
enabled = true enabled = true
if strings.HasPrefix(v, "http://") || strings.HasPrefix(v, "https://") { if strings.HasPrefix(v, "http://") || strings.HasPrefix(v, "https://") {
@ -171,7 +176,7 @@ func (d *Docker) parseContainerInfo(c containerInfo) (res []discovery.URLMapper)
// docker server label may have multiple, comma separated servers // docker server label may have multiple, comma separated servers
for _, srv := range strings.Split(server, ",") { for _, srv := range strings.Split(server, ",") {
mp := discovery.URLMapper{Server: strings.TrimSpace(srv), SrcMatch: *srcRegex, Dst: destURL, mp := discovery.URLMapper{Server: strings.TrimSpace(srv), SrcMatch: *srcRegex, Dst: destURL,
PingURL: pingURL, ProviderID: discovery.PIDocker, MatchType: discovery.MTProxy} PingURL: pingURL, OnlyFromIPs: onlyFrom, ProviderID: discovery.PIDocker, MatchType: discovery.MTProxy}
// for assets we add the second proxy mapping only if explicitly requested // for assets we add the second proxy mapping only if explicitly requested
if assetsWebRoot != "" && explicit { if assetsWebRoot != "" && explicit {

View File

@ -30,7 +30,7 @@ func TestDocker_List(t *testing.T) {
{ {
Name: "c1", State: "running", IP: "127.0.0.2", Ports: []int{12345}, Name: "c1", State: "running", IP: "127.0.0.2", Ports: []int{12345},
Labels: map[string]string{"reproxy.route": "^/api/123/(.*)", "reproxy.dest": "/blah/$1", Labels: map[string]string{"reproxy.route": "^/api/123/(.*)", "reproxy.dest": "/blah/$1",
"reproxy.server": "example.com", "reproxy.ping": "/ping"}, "reproxy.server": "example.com", "reproxy.ping": "/ping", "reproxy.remote": "192.168.1.0/24, 127.0.0.1"},
}, },
{ {
Name: "c1", State: "running", IP: "127.0.0.21", Ports: []int{12345}, Name: "c1", State: "running", IP: "127.0.0.21", Ports: []int{12345},
@ -64,21 +64,25 @@ func TestDocker_List(t *testing.T) {
assert.Equal(t, "http://127.0.0.2:12345/blah/$1", res[0].Dst) assert.Equal(t, "http://127.0.0.2:12345/blah/$1", res[0].Dst)
assert.Equal(t, "example.com", res[0].Server) assert.Equal(t, "example.com", res[0].Server)
assert.Equal(t, "http://127.0.0.2:12345/ping", res[0].PingURL) assert.Equal(t, "http://127.0.0.2:12345/ping", res[0].PingURL)
assert.Equal(t, []string{"192.168.1.0/24", "127.0.0.1"}, res[0].OnlyFromIPs)
assert.Equal(t, "^/api/90/(.*)", res[1].SrcMatch.String()) assert.Equal(t, "^/api/90/(.*)", res[1].SrcMatch.String())
assert.Equal(t, "http://example.com/blah/$1", res[1].Dst) assert.Equal(t, "http://example.com/blah/$1", res[1].Dst)
assert.Equal(t, "https://example.com//ping", res[1].PingURL) assert.Equal(t, "https://example.com//ping", res[1].PingURL)
assert.Equal(t, "example.com", res[1].Server) assert.Equal(t, "example.com", res[1].Server)
assert.Equal(t, []string{}, res[1].OnlyFromIPs)
assert.Equal(t, "^/c2/(.*)", res[2].SrcMatch.String()) assert.Equal(t, "^/c2/(.*)", res[2].SrcMatch.String())
assert.Equal(t, "http://127.0.0.3:12346/$1", res[2].Dst) assert.Equal(t, "http://127.0.0.3:12346/$1", res[2].Dst)
assert.Equal(t, "http://127.0.0.3:12346/ping", res[2].PingURL) assert.Equal(t, "http://127.0.0.3:12346/ping", res[2].PingURL)
assert.Equal(t, "*", res[2].Server) assert.Equal(t, "*", res[2].Server)
assert.Equal(t, []string{}, res[2].OnlyFromIPs)
assert.Equal(t, "^/a/(.*)", res[3].SrcMatch.String()) assert.Equal(t, "^/a/(.*)", res[3].SrcMatch.String())
assert.Equal(t, "http://127.0.0.2:12348/a/$1", res[3].Dst) assert.Equal(t, "http://127.0.0.2:12348/a/$1", res[3].Dst)
assert.Equal(t, "http://127.0.0.2:12348/ping", res[3].PingURL) assert.Equal(t, "http://127.0.0.2:12348/ping", res[3].PingURL)
assert.Equal(t, "example.com", res[3].Server) assert.Equal(t, "example.com", res[3].Server)
assert.Equal(t, []string{}, res[3].OnlyFromIPs)
} }
func TestDocker_ListMulti(t *testing.T) { func TestDocker_ListMulti(t *testing.T) {

View File

@ -84,6 +84,7 @@ func (d *File) List() (res []discovery.URLMapper, err error) {
Ping string `yaml:"ping"` Ping string `yaml:"ping"`
AssetsEnabled bool `yaml:"assets"` AssetsEnabled bool `yaml:"assets"`
AssetsSPA bool `yaml:"spa"` AssetsSPA bool `yaml:"spa"`
OnlyFrom string `yaml:"remote"`
} }
fh, err := os.Open(d.FileName) fh, err := os.Open(d.FileName)
if err != nil { if err != nil {
@ -112,6 +113,7 @@ func (d *File) List() (res []discovery.URLMapper, err error) {
PingURL: f.Ping, PingURL: f.Ping,
ProviderID: discovery.PIFile, ProviderID: discovery.PIFile,
MatchType: discovery.MTProxy, MatchType: discovery.MTProxy,
OnlyFromIPs: discovery.ParseOnlyFrom(f.OnlyFrom),
} }
if f.AssetsEnabled || f.AssetsSPA { if f.AssetsEnabled || f.AssetsSPA {
mapper.MatchType = discovery.MTStatic mapper.MatchType = discovery.MTStatic

View File

@ -113,18 +113,21 @@ func TestFile_List(t *testing.T) {
assert.Equal(t, "", res[0].PingURL) assert.Equal(t, "", res[0].PingURL)
assert.Equal(t, "srv.example.com", res[0].Server) assert.Equal(t, "srv.example.com", res[0].Server)
assert.Equal(t, discovery.MTProxy, res[0].MatchType) assert.Equal(t, discovery.MTProxy, res[0].MatchType)
assert.Equal(t, []string{}, res[0].OnlyFromIPs)
assert.Equal(t, "^/api/svc1/(.*)", res[1].SrcMatch.String()) assert.Equal(t, "^/api/svc1/(.*)", res[1].SrcMatch.String())
assert.Equal(t, "http://127.0.0.1:8080/blah1/$1", res[1].Dst) assert.Equal(t, "http://127.0.0.1:8080/blah1/$1", res[1].Dst)
assert.Equal(t, "", res[1].PingURL) assert.Equal(t, "", res[1].PingURL)
assert.Equal(t, "*", res[1].Server) assert.Equal(t, "*", res[1].Server)
assert.Equal(t, discovery.MTProxy, res[1].MatchType) assert.Equal(t, discovery.MTProxy, res[1].MatchType)
assert.Equal(t, []string{}, res[0].OnlyFromIPs)
assert.Equal(t, "/api/svc3/xyz", res[2].SrcMatch.String()) assert.Equal(t, "/api/svc3/xyz", res[2].SrcMatch.String())
assert.Equal(t, "http://127.0.0.3:8080/blah3/xyz", res[2].Dst) assert.Equal(t, "http://127.0.0.3:8080/blah3/xyz", res[2].Dst)
assert.Equal(t, "http://127.0.0.3:8080/ping", res[2].PingURL) assert.Equal(t, "http://127.0.0.3:8080/ping", res[2].PingURL)
assert.Equal(t, "*", res[2].Server) assert.Equal(t, "*", res[2].Server)
assert.Equal(t, discovery.MTProxy, res[2].MatchType) assert.Equal(t, discovery.MTProxy, res[2].MatchType)
assert.Equal(t, []string{}, res[0].OnlyFromIPs)
assert.Equal(t, "/web/", res[3].SrcMatch.String()) assert.Equal(t, "/web/", res[3].SrcMatch.String())
assert.Equal(t, "/var/web", res[3].Dst) assert.Equal(t, "/var/web", res[3].Dst)
@ -132,6 +135,7 @@ func TestFile_List(t *testing.T) {
assert.Equal(t, "*", res[3].Server) assert.Equal(t, "*", res[3].Server)
assert.Equal(t, discovery.MTStatic, res[3].MatchType) assert.Equal(t, discovery.MTStatic, res[3].MatchType)
assert.Equal(t, false, res[3].AssetsSPA) assert.Equal(t, false, res[3].AssetsSPA)
assert.Equal(t, []string{"192.168.1.0/24", "124.0.0.1"}, res[3].OnlyFromIPs)
assert.Equal(t, "/web2/", res[4].SrcMatch.String()) assert.Equal(t, "/web2/", res[4].SrcMatch.String())
assert.Equal(t, "/var/web2", res[4].Dst) assert.Equal(t, "/var/web2", res[4].Dst)
@ -139,4 +143,5 @@ func TestFile_List(t *testing.T) {
assert.Equal(t, "*", res[4].Server) assert.Equal(t, "*", res[4].Server)
assert.Equal(t, discovery.MTStatic, res[4].MatchType) assert.Equal(t, discovery.MTStatic, res[4].MatchType)
assert.Equal(t, true, res[4].AssetsSPA) assert.Equal(t, true, res[4].AssetsSPA)
assert.Equal(t, []string{}, res[0].OnlyFromIPs)
} }

View File

@ -9,7 +9,7 @@ import (
"github.com/umputun/reproxy/app/discovery" "github.com/umputun/reproxy/app/discovery"
) )
// Static provider, rules are server,from,to // Static provider, rules are server,source_url,destination[,ping]
type Static struct { type Static struct {
Rules []string // each rule is 4 elements comma separated - server,source_url,destination,ping Rules []string // each rule is 4 elements comma separated - server,source_url,destination,ping
} }

View File

@ -1,7 +1,7 @@
default: default:
- {route: "^/api/svc1/(.*)", dest: "http://127.0.0.1:8080/blah1/$1"} - {route: "^/api/svc1/(.*)", dest: "http://127.0.0.1:8080/blah1/$1"}
- {route: "/api/svc3/xyz", dest: "http://127.0.0.3:8080/blah3/xyz", "ping": "http://127.0.0.3:8080/ping"} - {route: "/api/svc3/xyz", dest: "http://127.0.0.3:8080/blah3/xyz", "ping": "http://127.0.0.3:8080/ping"}
- {route: "/web/", dest: "/var/web", "assets": yes} - {route: "/web/", dest: "/var/web", "assets": yes, "remote": "192.168.1.0/24, 124.0.0.1"}
- {route: "/web2/", dest: "/var/web2", "spa": yes} - {route: "/web2/", dest: "/var/web2", "spa": yes}
srv.example.com: srv.example.com:
- {route: "^/api/svc2/(.*)", dest: "http://127.0.0.2:8080/blah2/$1/abc"} - {route: "^/api/svc2/(.*)", dest: "http://127.0.0.2:8080/blah2/$1/abc"}

View File

@ -35,7 +35,7 @@ var opts struct {
ProxyHeaders []string `short:"x" long:"header" description:"outgoing proxy headers to add"` // env HEADER split in code to allow , inside "" ProxyHeaders []string `short:"x" long:"header" description:"outgoing proxy headers to add"` // env HEADER split in code to allow , inside ""
DropHeaders []string `long:"drop-header" env:"DROP_HEADERS" description:"incoming headers to drop" env-delim:","` DropHeaders []string `long:"drop-header" env:"DROP_HEADERS" description:"incoming headers to drop" env-delim:","`
AuthBasicHtpasswd string `long:"basic-htpasswd" env:"BASIC_HTPASSWD" description:"htpasswd file for basic auth"` AuthBasicHtpasswd string `long:"basic-htpasswd" env:"BASIC_HTPASSWD" description:"htpasswd file for basic auth"`
RemoteLookupHeaders bool `long:"remote-lookup-headers" env:"REMOTE_LOOKUP_HEADERS" description:"enable remote lookup headers"`
LBType string `long:"lb-type" env:"LB_TYPE" description:"load balancer type" choice:"random" choice:"failover" default:"random"` // nolint LBType string `long:"lb-type" env:"LB_TYPE" description:"load balancer type" choice:"random" choice:"failover" default:"random"` // nolint
SSL struct { SSL struct {
@ -273,10 +273,11 @@ func run() error {
ThrottleUser: opts.Throttle.User, ThrottleUser: opts.Throttle.User,
BasicAuthEnabled: len(basicAuthAllowed) > 0, BasicAuthEnabled: len(basicAuthAllowed) > 0,
BasicAuthAllowed: basicAuthAllowed, BasicAuthAllowed: basicAuthAllowed,
OnlyFrom: makeOnlyFromMiddleware(),
} }
err = px.Run(ctx) err = px.Run(ctx)
if err != nil && err == http.ErrServerClosed { if err != nil && errors.Is(err, http.ErrServerClosed) {
log.Printf("[WARN] proxy server closed, %v", err) // nolint gocritic log.Printf("[WARN] proxy server closed, %v", err) // nolint gocritic
return nil return nil
} }
@ -424,6 +425,13 @@ func makeLBSelector() func(len int) int {
} }
} }
func makeOnlyFromMiddleware() *proxy.OnlyFrom {
if opts.RemoteLookupHeaders {
return proxy.NewOnlyFrom(proxy.OFRealIP, proxy.OFForwarded, proxy.OFRemoteAddr)
}
return proxy.NewOnlyFrom(proxy.OFRemoteAddr)
}
func makeErrorReporter() (proxy.Reporter, error) { func makeErrorReporter() (proxy.Reporter, error) {
result := &proxy.ErrorReporter{ result := &proxy.ErrorReporter{
Nice: opts.ErrorReport.Enabled, Nice: opts.ErrorReport.Enabled,

View File

@ -244,5 +244,4 @@ func TestHttp_basicAuthHandler(t *testing.T) {
require.Equal(t, http.StatusOK, resp.StatusCode) require.Equal(t, http.StatusOK, resp.StatusCode)
}) })
} }
} }

149
app/proxy/only_from.go Normal file
View File

@ -0,0 +1,149 @@
package proxy
import (
"bytes"
"net"
"net/http"
"strings"
"github.com/umputun/reproxy/app/discovery"
)
// OnlyFrom implements middleware to allow access for a limited list of source IPs.
type OnlyFrom struct {
lookups []OFLookup
}
// OFLookup defines lookup method for source IP.
type OFLookup string
// enum of possible lookup methods
const (
OFRemoteAddr OFLookup = "remote-addr"
OFRealIP OFLookup = "real-ip"
OFForwarded OFLookup = "forwarded"
)
// NewOnlyFrom creates OnlyFrom middleware with given lookup methods.
func NewOnlyFrom(lookups ...OFLookup) *OnlyFrom {
return &OnlyFrom{lookups: lookups}
}
// Handler implements middleware interface.
func (o *OnlyFrom) Handler(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
var allowedIPs []string
reqCtx := r.Context()
if reqCtx.Value(ctxMatch) != nil { // route match detected by matchHandler
match := reqCtx.Value(ctxMatch).(discovery.MatchedRoute)
allowedIPs = match.Mapper.OnlyFromIPs
}
if len(allowedIPs) == 0 {
// no restrictions if no ips defined
next.ServeHTTP(w, r)
return
}
realIP := o.realIP(o.lookups, r)
if realIP != "" && o.matchRemoteIP(realIP, allowedIPs) {
next.ServeHTTP(w, r)
return
}
w.WriteHeader(http.StatusForbidden)
}
return http.HandlerFunc(fn)
}
func (o *OnlyFrom) realIP(ipLookups []OFLookup, r *http.Request) string {
realIP := r.Header.Get("X-Real-IP")
forwardedFor := r.Header.Get("X-Forwarded-For")
for _, lookup := range ipLookups {
if lookup == OFRemoteAddr {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr // can't parse, return as is
}
return ip
}
if lookup == OFForwarded && forwardedFor != "" {
// X-Forwarded-For is potentially a list of addresses separated with ","
// The left-most being the original client, and each successive proxy that passed the request
// adding the IP address where it received the request from.
// In case if the original IP is a private behind a proxy, we need to get the first public IP from the list
return preferPublicIP(strings.Split(forwardedFor, ","))
}
if lookup == OFRealIP && realIP != "" {
return realIP
}
}
return "" // we can't get real ip
}
// matchRemoteIP returns true if request's ip matches any of ips in the list of allowedIPs.
// allowedIPs can be defined as IP (like 192.168.1.12) or CIDR (192.168.0.0/16)
func (o *OnlyFrom) matchRemoteIP(remoteIP string, allowedIPs []string) bool {
for _, allowedIP := range allowedIPs {
// check for ip prefix or CIDR
if _, cidrnet, err := net.ParseCIDR(allowedIP); err == nil {
if cidrnet.Contains(net.ParseIP(remoteIP)) {
return true
}
}
// check for ip match
if remoteIP == allowedIP {
return true
}
}
return false
}
// preferPublicIP returns first public IP from the list of IPs
// if no public IP found, returns first IP from the list
func preferPublicIP(ips []string) string {
for _, ip := range ips {
ip = strings.TrimSpace(ip)
if net.ParseIP(ip).IsGlobalUnicast() && !isPrivateSubnet(net.ParseIP(ip)) {
return ip
}
}
return strings.TrimSpace(ips[0])
}
type ipRange struct {
start net.IP
end net.IP
}
var privateRanges = []ipRange{
{start: net.ParseIP("10.0.0.0"), end: net.ParseIP("10.255.255.255")},
{start: net.ParseIP("100.64.0.0"), end: net.ParseIP("100.127.255.255")},
{start: net.ParseIP("172.16.0.0"), end: net.ParseIP("172.31.255.255")},
{start: net.ParseIP("192.0.0.0"), end: net.ParseIP("192.0.0.255")},
{start: net.ParseIP("192.168.0.0"), end: net.ParseIP("192.168.255.255")},
{start: net.ParseIP("198.18.0.0"), end: net.ParseIP("198.19.255.255")},
{start: net.ParseIP("::1"), end: net.ParseIP("::1")},
{start: net.ParseIP("fc00::"), end: net.ParseIP("fdff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")},
{start: net.ParseIP("fe80::"), end: net.ParseIP("febf:ffff:ffff:ffff:ffff:ffff:ffff:ffff")},
}
// isPrivateSubnet - check to see if this ip is in a private subnet
func isPrivateSubnet(ipAddress net.IP) bool {
inRange := func(r ipRange, ipAddress net.IP) bool {
// ensure the IPs are in the same format for comparison
ipAddress = ipAddress.To16()
r.start = r.start.To16()
r.end = r.end.To16()
return bytes.Compare(ipAddress, r.start) >= 0 && bytes.Compare(ipAddress, r.end) <= 0
}
for _, r := range privateRanges {
if inRange(r, ipAddress) {
return true
}
}
return false
}

144
app/proxy/only_from_test.go Normal file
View File

@ -0,0 +1,144 @@
package proxy
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/umputun/reproxy/app/discovery"
)
func TestOnlyFrom_Handler(t *testing.T) {
tbl := []struct {
name string
lookups []OFLookup
allowedIPs []string
remoteAddr string
realIP string
forwardedFor string
expectedStatusCode int
}{
{
name: "allowed IP",
lookups: []OFLookup{OFRemoteAddr},
allowedIPs: []string{"192.168.1.1"},
remoteAddr: "192.168.1.1:1234",
expectedStatusCode: http.StatusOK,
},
{
name: "disallowed IP",
lookups: []OFLookup{OFRemoteAddr},
allowedIPs: []string{"192.168.1.1"},
remoteAddr: "192.168.1.2:1234",
expectedStatusCode: http.StatusForbidden,
},
{
name: "no restrictions",
lookups: []OFLookup{OFRemoteAddr},
allowedIPs: []string{},
remoteAddr: "192.168.1.2:1234",
expectedStatusCode: http.StatusOK,
},
{
name: "allowed IP with RealIP lookup",
lookups: []OFLookup{OFRealIP},
allowedIPs: []string{"192.168.1.1"},
realIP: "192.168.1.1",
expectedStatusCode: http.StatusOK,
},
{
name: "disallowed IP with RealIP lookup",
lookups: []OFLookup{OFRealIP},
allowedIPs: []string{"192.168.1.1"},
realIP: "192.168.1.2",
expectedStatusCode: http.StatusForbidden,
},
{
name: "allowed IP with Forwarded lookup",
lookups: []OFLookup{OFForwarded},
allowedIPs: []string{"192.168.1.1"},
forwardedFor: "192.168.1.1",
expectedStatusCode: http.StatusOK,
},
{
name: "allowed IP with Forwarded lookup, mix private and public IPs",
lookups: []OFLookup{OFForwarded},
allowedIPs: []string{"8.8.8.8"},
forwardedFor: "192.168.1.1, 10.0.0.5, 8.8.8.8, 10.10.10.10",
expectedStatusCode: http.StatusOK,
},
{
name: "disallowed IP with Forwarded lookup",
lookups: []OFLookup{OFForwarded},
allowedIPs: []string{"192.168.1.1"},
forwardedFor: "192.168.1.2",
expectedStatusCode: http.StatusForbidden,
},
{
name: "multiple lookups, allowed IP",
lookups: []OFLookup{OFRemoteAddr, OFRealIP},
allowedIPs: []string{"192.168.1.1", "192.168.1.2"},
remoteAddr: "192.168.1.2:1234",
realIP: "192.168.1.1",
expectedStatusCode: http.StatusOK,
},
{
name: "multiple lookups, disallowed IP",
lookups: []OFLookup{OFRemoteAddr, OFRealIP},
allowedIPs: []string{"192.168.1.1", "192.168.1.2"},
remoteAddr: "192.168.1.3:1234",
realIP: "192.168.1.3",
expectedStatusCode: http.StatusForbidden,
},
{
name: "CIDR block, allowed IP",
lookups: []OFLookup{OFRemoteAddr},
allowedIPs: []string{"192.168.1.0/24"},
remoteAddr: "192.168.1.2:1234",
expectedStatusCode: http.StatusOK,
},
{
name: "CIDR block, disallowed IP",
lookups: []OFLookup{OFRemoteAddr},
allowedIPs: []string{"192.168.1.0/24"},
remoteAddr: "192.168.2.2:1234",
expectedStatusCode: http.StatusForbidden,
},
{
name: "invalid remote address format",
lookups: []OFLookup{OFRemoteAddr},
allowedIPs: []string{"192.168.1.1"},
remoteAddr: "invalid_format",
expectedStatusCode: http.StatusForbidden,
},
{
name: "empty remote address",
lookups: []OFLookup{OFRemoteAddr},
allowedIPs: []string{"192.168.1.1"},
remoteAddr: "",
expectedStatusCode: http.StatusForbidden,
},
}
for _, tt := range tbl {
t.Run(tt.name, func(t *testing.T) {
onlyFrom := NewOnlyFrom(tt.lookups...)
handler := onlyFrom.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
req := httptest.NewRequest("GET", "http://example.com/foo", http.NoBody)
req.RemoteAddr = tt.remoteAddr
req.Header.Set("X-Real-IP", tt.realIP)
req.Header.Set("X-Forwarded-For", tt.forwardedFor)
req = req.WithContext(context.WithValue(req.Context(),
ctxMatch, discovery.MatchedRoute{Mapper: discovery.URLMapper{OnlyFromIPs: tt.allowedIPs}}))
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, tt.expectedStatusCode, rr.Code)
})
}
}

View File

@ -48,7 +48,7 @@ type Http struct { // nolint golint
PluginConductor MiddlewareProvider PluginConductor MiddlewareProvider
Reporter Reporter Reporter Reporter
LBSelector func(len int) int LBSelector func(len int) int
OnlyFrom *OnlyFrom
BasicAuthEnabled bool BasicAuthEnabled bool
BasicAuthAllowed []string BasicAuthAllowed []string
@ -123,6 +123,7 @@ func (h *Http) Run(ctx context.Context) error {
handler := R.Wrap(h.proxyHandler(), handler := R.Wrap(h.proxyHandler(),
R.Recoverer(log.Default()), // recover on errors R.Recoverer(log.Default()), // recover on errors
signatureHandler(h.Signature, h.Version), // send app signature signatureHandler(h.Signature, h.Version), // send app signature
h.OnlyFrom.Handler, // limit source (remote) IPs if defined
h.pingHandler, // respond to /ping h.pingHandler, // respond to /ping
basicAuthHandler(h.BasicAuthEnabled, h.BasicAuthAllowed), // basic auth basicAuthHandler(h.BasicAuthEnabled, h.BasicAuthAllowed), // basic auth
h.healthMiddleware, // respond to /health h.healthMiddleware, // respond to /health
@ -400,22 +401,22 @@ func (h *Http) makeHTTPServer(addr string, router http.Handler) *http.Server {
} }
func (h *Http) setXRealIP(r *http.Request) { func (h *Http) setXRealIP(r *http.Request) {
if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
remoteIP := r.Header.Get("X-Forwarded-For") // use the left-most non-private client IP address
if remoteIP == "" { // if there is no any non-private IP address, use the left-most address
remoteIP = r.RemoteAddr r.Header.Set("X-Real-IP", preferPublicIP(strings.Split(forwarded, ",")))
}
ip, _, err := net.SplitHostPort(remoteIP)
if err != nil {
return return
} }
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return
}
userIP := net.ParseIP(ip) userIP := net.ParseIP(ip)
if userIP == nil { if userIP == nil {
return return
} }
r.Header.Add("X-Real-IP", ip) r.Header.Set("X-Real-IP", ip)
} }
// discoveredServers gets the list of servers discovered by providers. // discoveredServers gets the list of servers discovered by providers.

View File

@ -34,6 +34,7 @@ func TestHttp_Do(t *testing.T) {
t.Logf("req: %v", r) t.Logf("req: %v", r)
w.Header().Add("h1", "v1") w.Header().Add("h1", "v1")
require.Equal(t, "127.0.0.1", r.Header.Get("X-Real-IP")) require.Equal(t, "127.0.0.1", r.Header.Get("X-Real-IP"))
require.Equal(t, "127.0.0.1", r.Header.Get("X-Forwarded-For"))
fmt.Fprintf(w, "response %s", r.URL.String()) fmt.Fprintf(w, "response %s", r.URL.String())
})) }))
@ -59,7 +60,7 @@ func TestHttp_Do(t *testing.T) {
client := http.Client{} client := http.Client{}
{ t.Run("to 127.0.0.1, good", func(t *testing.T) {
req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody) req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody)
require.NoError(t, err) require.NoError(t, err)
resp, err := client.Do(req) resp, err := client.Do(req)
@ -75,9 +76,9 @@ func TestHttp_Do(t *testing.T) {
assert.Equal(t, "v1", resp.Header.Get("h1")) assert.Equal(t, "v1", resp.Header.Get("h1"))
assert.Equal(t, "vv1", resp.Header.Get("hh1")) assert.Equal(t, "vv1", resp.Header.Get("hh1"))
assert.Equal(t, "vv2", resp.Header.Get("hh2")) assert.Equal(t, "vv2", resp.Header.Get("hh2"))
} })
{ t.Run("to localhost, good", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/api/something") resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/api/something")
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
@ -89,9 +90,9 @@ func TestHttp_Do(t *testing.T) {
assert.Equal(t, "response /123/something", string(body)) assert.Equal(t, "response /123/something", string(body))
assert.Equal(t, "reproxy", resp.Header.Get("App-Name")) assert.Equal(t, "reproxy", resp.Header.Get("App-Name"))
assert.Equal(t, "v1", resp.Header.Get("h1")) assert.Equal(t, "v1", resp.Header.Get("h1"))
} })
{ t.Run("bad gateway", func(t *testing.T) {
resp, err := client.Get("http://127.0.0.1:" + strconv.Itoa(port) + "/bad/something") resp, err := client.Get("http://127.0.0.1:" + strconv.Itoa(port) + "/bad/something")
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
@ -100,9 +101,9 @@ func TestHttp_Do(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Contains(t, string(b), "Sorry for the inconvenience") assert.Contains(t, string(b), "Sorry for the inconvenience")
assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type")) assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type"))
} })
{ t.Run("url encode", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/api/test%20%25%20and%20&,%20and%20other%20characters%20@%28%29%5E%21") resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/api/test%20%25%20and%20&,%20and%20other%20characters%20@%28%29%5E%21")
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
@ -114,7 +115,7 @@ func TestHttp_Do(t *testing.T) {
assert.Equal(t, "response /123/test%20%25%20and%20&,%20and%20other%20characters%20@%28%29%5E%21", string(body)) assert.Equal(t, "response /123/test%20%25%20and%20&,%20and%20other%20characters%20@%28%29%5E%21", string(body))
assert.Equal(t, "reproxy", resp.Header.Get("App-Name")) assert.Equal(t, "reproxy", resp.Header.Get("App-Name"))
assert.Equal(t, "v1", resp.Header.Get("h1")) assert.Equal(t, "v1", resp.Header.Get("h1"))
} })
} }
func TestHttp_DoWithAssets(t *testing.T) { func TestHttp_DoWithAssets(t *testing.T) {
@ -153,7 +154,7 @@ func TestHttp_DoWithAssets(t *testing.T) {
client := http.Client{} client := http.Client{}
{ t.Run("api call", func(t *testing.T) {
req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody) req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody)
require.NoError(t, err) require.NoError(t, err)
resp, err := client.Do(req) resp, err := client.Do(req)
@ -167,9 +168,9 @@ func TestHttp_DoWithAssets(t *testing.T) {
assert.Equal(t, "response /567/something", string(body)) assert.Equal(t, "response /567/something", string(body))
assert.Equal(t, "", resp.Header.Get("App-Method")) assert.Equal(t, "", resp.Header.Get("App-Method"))
assert.Equal(t, "v1", resp.Header.Get("h1")) assert.Equal(t, "v1", resp.Header.Get("h1"))
} })
{ t.Run("static call, good", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/1.html") resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/1.html")
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
@ -182,9 +183,9 @@ func TestHttp_DoWithAssets(t *testing.T) {
assert.Equal(t, "", resp.Header.Get("App-Method")) assert.Equal(t, "", resp.Header.Get("App-Method"))
assert.Equal(t, "", resp.Header.Get("h1")) assert.Equal(t, "", resp.Header.Get("h1"))
assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control")) assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control"))
} })
{ t.Run("static call, bad", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad.html") resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad.html")
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
@ -192,9 +193,9 @@ func TestHttp_DoWithAssets(t *testing.T) {
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "404 page not found\n", string(body)) assert.Equal(t, "404 page not found\n", string(body))
} })
{ t.Run("bad url", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/svcbad") resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/svcbad")
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
@ -203,7 +204,7 @@ func TestHttp_DoWithAssets(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Contains(t, string(body), "Server error") assert.Contains(t, string(body), "Server error")
assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type")) assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type"))
} })
} }
func TestHttp_DoWithAssetsCustom404(t *testing.T) { func TestHttp_DoWithAssetsCustom404(t *testing.T) {
@ -243,7 +244,7 @@ func TestHttp_DoWithAssetsCustom404(t *testing.T) {
client := http.Client{} client := http.Client{}
{ t.Run("api call, found", func(t *testing.T) {
req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody) req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody)
require.NoError(t, err) require.NoError(t, err)
resp, err := client.Do(req) resp, err := client.Do(req)
@ -257,9 +258,9 @@ func TestHttp_DoWithAssetsCustom404(t *testing.T) {
assert.Equal(t, "response /567/something", string(body)) assert.Equal(t, "response /567/something", string(body))
assert.Equal(t, "", resp.Header.Get("App-Method")) assert.Equal(t, "", resp.Header.Get("App-Method"))
assert.Equal(t, "v1", resp.Header.Get("h1")) assert.Equal(t, "v1", resp.Header.Get("h1"))
} })
{ t.Run("static call, found", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/1.html") resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/1.html")
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
@ -272,9 +273,9 @@ func TestHttp_DoWithAssetsCustom404(t *testing.T) {
assert.Equal(t, "", resp.Header.Get("App-Method")) assert.Equal(t, "", resp.Header.Get("App-Method"))
assert.Equal(t, "", resp.Header.Get("h1")) assert.Equal(t, "", resp.Header.Get("h1"))
assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control")) assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control"))
} })
{ t.Run("static call, not found", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad.html") resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad.html")
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
@ -284,9 +285,9 @@ func TestHttp_DoWithAssetsCustom404(t *testing.T) {
assert.Equal(t, "not found! blah blah blah\nthere is no spoon", string(body)) assert.Equal(t, "not found! blah blah blah\nthere is no spoon", string(body))
t.Logf("%+v", resp.Header) t.Logf("%+v", resp.Header)
assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type")) assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type"))
} })
{ t.Run("another static call, not found", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad2.html") resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad2.html")
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
@ -296,7 +297,7 @@ func TestHttp_DoWithAssetsCustom404(t *testing.T) {
assert.Equal(t, "not found! blah blah blah\nthere is no spoon", string(body)) assert.Equal(t, "not found! blah blah blah\nthere is no spoon", string(body))
t.Logf("%+v", resp.Header) t.Logf("%+v", resp.Header)
assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type")) assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type"))
} })
} }
func TestHttp_DoWithSpaAssets(t *testing.T) { func TestHttp_DoWithSpaAssets(t *testing.T) {
@ -336,7 +337,7 @@ func TestHttp_DoWithSpaAssets(t *testing.T) {
client := http.Client{} client := http.Client{}
{ t.Run("api call, good", func(t *testing.T) {
req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody) req, err := http.NewRequest("GET", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", http.NoBody)
require.NoError(t, err) require.NoError(t, err)
resp, err := client.Do(req) resp, err := client.Do(req)
@ -350,9 +351,9 @@ func TestHttp_DoWithSpaAssets(t *testing.T) {
assert.Equal(t, "response /567/something", string(body)) assert.Equal(t, "response /567/something", string(body))
assert.Equal(t, "", resp.Header.Get("App-Method")) assert.Equal(t, "", resp.Header.Get("App-Method"))
assert.Equal(t, "v1", resp.Header.Get("h1")) assert.Equal(t, "v1", resp.Header.Get("h1"))
} })
{ t.Run("static call, good", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/1.html") resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/1.html")
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
@ -365,9 +366,9 @@ func TestHttp_DoWithSpaAssets(t *testing.T) {
assert.Equal(t, "", resp.Header.Get("App-Method")) assert.Equal(t, "", resp.Header.Get("App-Method"))
assert.Equal(t, "", resp.Header.Get("h1")) assert.Equal(t, "", resp.Header.Get("h1"))
assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control")) assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control"))
} })
{ t.Run("static call, not found server index", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad.html") resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/static/bad.html")
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
@ -380,9 +381,9 @@ func TestHttp_DoWithSpaAssets(t *testing.T) {
assert.Equal(t, "", resp.Header.Get("App-Method")) assert.Equal(t, "", resp.Header.Get("App-Method"))
assert.Equal(t, "", resp.Header.Get("h1")) assert.Equal(t, "", resp.Header.Get("h1"))
assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control")) assert.Equal(t, "public, max-age=43200", resp.Header.Get("Cache-Control"))
} })
{ t.Run("static call, bad url", func(t *testing.T) {
resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/svcbad") resp, err := client.Get("http://localhost:" + strconv.Itoa(port) + "/svcbad")
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
@ -391,7 +392,7 @@ func TestHttp_DoWithSpaAssets(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Contains(t, string(body), "Server error") assert.Contains(t, string(body), "Server error")
assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type")) assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type"))
} })
} }
func TestHttp_DoWithAssetRules(t *testing.T) { func TestHttp_DoWithAssetRules(t *testing.T) {
@ -715,16 +716,16 @@ func TestHttp_withBasicAuth(t *testing.T) {
client := http.Client{} client := http.Client{}
{ t.Run("no auth", func(t *testing.T) {
req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg")) req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
require.NoError(t, err) require.NoError(t, err)
resp, err := client.Do(req) resp, err := client.Do(req)
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
} })
{ t.Run("bad auth", func(t *testing.T) {
req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg")) req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
req.SetBasicAuth("test", "badpasswd") req.SetBasicAuth("test", "badpasswd")
require.NoError(t, err) require.NoError(t, err)
@ -732,8 +733,9 @@ func TestHttp_withBasicAuth(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
} })
{
t.Run("good auth", func(t *testing.T) {
req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg")) req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
req.SetBasicAuth("test", "passwd") req.SetBasicAuth("test", "passwd")
require.NoError(t, err) require.NoError(t, err)
@ -741,8 +743,9 @@ func TestHttp_withBasicAuth(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, http.StatusOK, resp.StatusCode)
} })
{
t.Run("good auth 2", func(t *testing.T) {
req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg")) req, err := http.NewRequest("POST", "http://127.0.0.1:"+strconv.Itoa(port)+"/api/something", bytes.NewBufferString("abcdefg"))
req.SetBasicAuth("test2", "passwd2") req.SetBasicAuth("test2", "passwd2")
require.NoError(t, err) require.NoError(t, err)
@ -750,7 +753,7 @@ func TestHttp_withBasicAuth(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, http.StatusOK, resp.StatusCode)
} })
} }
func TestHttp_toHttp(t *testing.T) { func TestHttp_toHttp(t *testing.T) {
@ -766,9 +769,9 @@ func TestHttp_toHttp(t *testing.T) {
} }
h := Http{} h := Http{}
for i, tt := range tbl { for _, tt := range tbl {
tt := tt tt := tt
t.Run(strconv.Itoa(i), func(t *testing.T) { t.Run(tt.addr, func(t *testing.T) {
assert.Equal(t, tt.res, h.toHTTP(tt.addr, tt.port)) assert.Equal(t, tt.res, h.toHTTP(tt.addr, tt.port))
}) })
} }
@ -791,8 +794,8 @@ func TestHttp_isAssetRequest(t *testing.T) {
{"/static/", "/tmp", "", false}, {"/static/", "/tmp", "", false},
} }
for i, tt := range tbl { for _, tt := range tbl {
t.Run(strconv.Itoa(i), func(t *testing.T) { t.Run(tt.req, func(t *testing.T) {
h := Http{AssetsLocation: tt.assetsLocation, AssetsWebRoot: tt.assetsWebRoot} h := Http{AssetsLocation: tt.assetsLocation, AssetsWebRoot: tt.assetsWebRoot}
r, err := http.NewRequest("GET", tt.req, http.NoBody) r, err := http.NewRequest("GET", tt.req, http.NoBody)
require.NoError(t, err) require.NoError(t, err)
@ -803,56 +806,61 @@ func TestHttp_isAssetRequest(t *testing.T) {
} }
func TestHttp_matchHandler(t *testing.T) { func TestHttp_matchHandler(t *testing.T) {
tbl := []struct { tbl := []struct {
name string
matches discovery.Matches matches discovery.Matches
res string res string
ok bool ok bool
}{ }{
{ {
discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{ name: "all alive destinations",
matches: discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
{Destination: "dest1", Alive: true}, {Destination: "dest1", Alive: true},
{Destination: "dest2", Alive: true}, {Destination: "dest2", Alive: true},
{Destination: "dest3", Alive: true}, {Destination: "dest3", Alive: true},
}}, }},
"dest1", true, res: "dest1", ok: true,
}, },
{ {
discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{ name: "second alive destination",
matches: discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
{Destination: "dest1", Alive: false}, {Destination: "dest1", Alive: false},
{Destination: "dest2", Alive: true}, {Destination: "dest2", Alive: true},
{Destination: "dest3", Alive: false}, {Destination: "dest3", Alive: false},
}}, }},
"dest2", true, res: "dest2", ok: true,
}, },
{ {
discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{ name: "one dead destination",
matches: discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
{Destination: "dest1", Alive: false}, {Destination: "dest1", Alive: false},
{Destination: "dest2", Alive: true}, {Destination: "dest2", Alive: true},
{Destination: "dest3", Alive: true}, {Destination: "dest3", Alive: true},
}}, }},
"dest2", true, res: "dest2", ok: true,
}, },
{ {
discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{ name: "last alive destination",
matches: discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
{Destination: "dest1", Alive: false}, {Destination: "dest1", Alive: false},
{Destination: "dest2", Alive: false}, {Destination: "dest2", Alive: false},
{Destination: "dest3", Alive: true}, {Destination: "dest3", Alive: true},
}}, }},
"dest3", true, res: "dest3", ok: true,
}, },
{ {
discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{ name: "all dead destinations",
matches: discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{
{Destination: "dest1", Alive: false}, {Destination: "dest1", Alive: false},
{Destination: "dest2", Alive: false}, {Destination: "dest2", Alive: false},
{Destination: "dest3", Alive: false}, {Destination: "dest3", Alive: false},
}}, }},
"", false, res: "", ok: false,
}, },
{ {
discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{}}, "", false, name: "no destinations",
matches: discovery.Matches{MatchType: discovery.MTProxy, Routes: []discovery.MatchedRoute{}}, res: "", ok: false,
}, },
} }
@ -864,9 +872,8 @@ func TestHttp_matchHandler(t *testing.T) {
} }
client := http.Client{} client := http.Client{}
for i, tt := range tbl { for _, tt := range tbl {
t.Run(strconv.Itoa(i), func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := Http{Matcher: matcherMock, LBSelector: func(len int) int { return 0 }} h := Http{Matcher: matcherMock, LBSelector: func(len int) int { return 0 }}
handler := h.matchHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := h.matchHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("req: %+v", r) t.Logf("req: %+v", r)
@ -893,7 +900,6 @@ func TestHttp_matchHandler(t *testing.T) {
} }
func TestHttp_discoveredServers(t *testing.T) { func TestHttp_discoveredServers(t *testing.T) {
calls := 0 calls := 0
m := &MatcherMock{ServersFunc: func() []string { m := &MatcherMock{ServersFunc: func() []string {
defer func() { calls++ }() defer func() { calls++ }()