mirror of
https://github.com/umputun/reproxy.git
synced 2025-02-16 18:34:30 +02:00
make LBSelector interface and implement all the current methods plus roundrobin
This commit is contained in:
parent
8bde167226
commit
fa23778d42
@ -364,7 +364,7 @@ This is the list of all options supporting multiple elements:
|
||||
-x, --header= outgoing proxy headers to add [$HEADER]
|
||||
--drop-header= incoming headers to drop [$DROP_HEADERS]
|
||||
--basic-htpasswd= htpasswd file for basic auth [$BASIC_HTPASSWD]
|
||||
--lb-type=[random|failover] load balancer type (default: random) [$LB_TYPE]
|
||||
--lb-type=[random|failover|roundrobin] load balancer type (default: random) [$LB_TYPE]
|
||||
--signature enable reproxy signature headers [$SIGNATURE]
|
||||
--remote-lookup-headers enable remote lookup headers [$REMOTE_LOOKUP_HEADERS]
|
||||
--dbg debug mode [$DEBUG]
|
||||
|
13
app/main.go
13
app/main.go
@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/rpc"
|
||||
"os"
|
||||
@ -36,7 +35,7 @@ var opts struct {
|
||||
DropHeaders []string `long:"drop-header" env:"DROP_HEADERS" description:"incoming headers to drop" env-delim:","`
|
||||
AuthBasicHtpasswd string `long:"basic-htpasswd" env:"BASIC_HTPASSWD" description:"htpasswd file for basic auth"`
|
||||
RemoteLookupHeaders bool `long:"remote-lookup-headers" env:"REMOTE_LOOKUP_HEADERS" description:"enable remote lookup headers"`
|
||||
LBType string `long:"lb-type" env:"LB_TYPE" description:"load balancer type" choice:"random" choice:"failover" default:"random"` // nolint
|
||||
LBType string `long:"lb-type" env:"LB_TYPE" description:"load balancer type" choice:"random" choice:"failover" choice:"roundrobin" default:"random"` // nolint
|
||||
|
||||
SSL struct {
|
||||
Type string `long:"type" env:"TYPE" description:"ssl (auto) support" choice:"none" choice:"static" choice:"auto" default:"none"` // nolint
|
||||
@ -414,14 +413,16 @@ func makeSSLConfig() (config proxy.SSLConfig, err error) {
|
||||
return config, err
|
||||
}
|
||||
|
||||
func makeLBSelector() func(len int) int {
|
||||
func makeLBSelector() proxy.LBSelector {
|
||||
switch opts.LBType {
|
||||
case "random":
|
||||
return rand.Intn
|
||||
return &proxy.RandomSelector{}
|
||||
case "failover":
|
||||
return func(int) int { return 0 } // dead server won't be in the list, we can safely pick the first one
|
||||
return &proxy.FailoverSelector{}
|
||||
case "roundrobin":
|
||||
return &proxy.RoundRobinSelector{}
|
||||
default:
|
||||
return func(int) int { return 0 }
|
||||
return &proxy.FailoverSelector{}
|
||||
}
|
||||
}
|
||||
|
||||
|
45
app/proxy/lb_selector.go
Normal file
45
app/proxy/lb_selector.go
Normal file
@ -0,0 +1,45 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// RoundRobinSelector is a simple round-robin selector, thread-safe
|
||||
type RoundRobinSelector struct {
|
||||
lastSelected int
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Select returns next backend index
|
||||
func (r *RoundRobinSelector) Select(n int) int {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
selected := r.lastSelected
|
||||
r.lastSelected = (r.lastSelected + 1) % n
|
||||
return selected
|
||||
}
|
||||
|
||||
// RandomSelector is a random selector, thread-safe
|
||||
type RandomSelector struct{}
|
||||
|
||||
// Select returns random backend index
|
||||
func (r *RandomSelector) Select(n int) int {
|
||||
return rand.Intn(n) //nolint:gosec // no need for crypto/rand here
|
||||
}
|
||||
|
||||
// FailoverSelector is a selector with failover, thread-safe
|
||||
type FailoverSelector struct{}
|
||||
|
||||
// Select returns next backend index
|
||||
func (r *FailoverSelector) Select(_ int) int {
|
||||
return 0 // dead server won't be in the list, we can safely pick the first one
|
||||
}
|
||||
|
||||
// LBSelectorFunc is a functional adapted for LBSelector to select backend from the list
|
||||
type LBSelectorFunc func(n int) int
|
||||
|
||||
// Select returns backend index
|
||||
func (f LBSelectorFunc) Select(n int) int {
|
||||
return f(n)
|
||||
}
|
121
app/proxy/lb_selector_test.go
Normal file
121
app/proxy/lb_selector_test.go
Normal file
@ -0,0 +1,121 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRoundRobinSelector_Select(t *testing.T) {
|
||||
selector := &RoundRobinSelector{}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
len int
|
||||
expected int
|
||||
}{
|
||||
{"First call", 3, 0},
|
||||
{"Second call", 3, 1},
|
||||
{"Third call", 3, 2},
|
||||
{"Back to zero", 3, 0},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := selector.Select(tc.len)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinSelector_SelectConcurrent(t *testing.T) {
|
||||
selector := &RoundRobinSelector{}
|
||||
l := 3
|
||||
numGoroutines := 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
results := &sync.Map{}
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
result := selector.Select(l)
|
||||
results.Store(result, struct{}{})
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// check that all possible results are present in the map.
|
||||
for i := 0; i < l; i++ {
|
||||
_, ok := results.Load(i)
|
||||
assert.True(t, ok, "expected to find %d in the results", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRandomSelector_Select(t *testing.T) {
|
||||
selector := &RandomSelector{}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
len int
|
||||
}{
|
||||
{"First call", 5},
|
||||
{"Second call", 5},
|
||||
{"Third call", 5},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := selector.Select(tc.len)
|
||||
assert.True(t, result >= 0 && result < tc.len)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailoverSelector_Select(t *testing.T) {
|
||||
selector := &FailoverSelector{}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
len int
|
||||
expected int
|
||||
}{
|
||||
{"First call", 5, 0},
|
||||
{"Second call", 5, 0},
|
||||
{"Third call", 5, 0},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := selector.Select(tc.len)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLBSelectorFunc_Select(t *testing.T) {
|
||||
selector := LBSelectorFunc(func(n int) int {
|
||||
return n - 1 // simple selection logic for testing
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
len int
|
||||
expected int
|
||||
}{
|
||||
{"First call", 5, 4},
|
||||
{"Second call", 3, 2},
|
||||
{"Third call", 1, 0},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := selector.Select(tc.len)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
@ -5,7 +5,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
@ -47,7 +46,7 @@ type Http struct { // nolint golint
|
||||
Metrics MiddlewareProvider
|
||||
PluginConductor MiddlewareProvider
|
||||
Reporter Reporter
|
||||
LBSelector func(len int) int
|
||||
LBSelector LBSelector
|
||||
OnlyFrom *OnlyFrom
|
||||
BasicAuthEnabled bool
|
||||
BasicAuthAllowed []string
|
||||
@ -75,6 +74,11 @@ type Reporter interface {
|
||||
Report(w http.ResponseWriter, code int)
|
||||
}
|
||||
|
||||
// LBSelector defines load balancer strategy
|
||||
type LBSelector interface {
|
||||
Select(len int) int // return index of picked server
|
||||
}
|
||||
|
||||
// Timeouts consolidate timeouts for both server and transport
|
||||
type Timeouts struct {
|
||||
// server timeouts
|
||||
@ -101,7 +105,7 @@ func (h *Http) Run(ctx context.Context) error {
|
||||
}
|
||||
|
||||
if h.LBSelector == nil {
|
||||
h.LBSelector = rand.Intn
|
||||
h.LBSelector = &RandomSelector{}
|
||||
}
|
||||
|
||||
var httpServer, httpsServer *http.Server
|
||||
@ -277,7 +281,7 @@ func (h *Http) proxyHandler() http.HandlerFunc {
|
||||
// and if match found sets it to the request context. Context used by proxy handler as well as by plugin conductor
|
||||
func (h *Http) matchHandler(next http.Handler) http.Handler {
|
||||
|
||||
getMatch := func(mm discovery.Matches, picker func(len int) int) (m discovery.MatchedRoute, ok bool) {
|
||||
getMatch := func(mm discovery.Matches, picker LBSelector) (m discovery.MatchedRoute, ok bool) {
|
||||
if len(mm.Routes) == 0 {
|
||||
return m, false
|
||||
}
|
||||
@ -294,7 +298,7 @@ func (h *Http) matchHandler(next http.Handler) http.Handler {
|
||||
case 1:
|
||||
return matches[0], true
|
||||
default:
|
||||
return matches[picker(len(matches))], true
|
||||
return matches[picker.Select(len(matches))], true
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -874,7 +874,7 @@ func TestHttp_matchHandler(t *testing.T) {
|
||||
client := http.Client{}
|
||||
for _, tt := range tbl {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := Http{Matcher: matcherMock, LBSelector: func(len int) int { return 0 }}
|
||||
h := Http{Matcher: matcherMock, LBSelector: &FailoverSelector{}}
|
||||
handler := h.matchHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Logf("req: %+v", r)
|
||||
t.Logf("dst: %v", r.Context().Value(ctxURL))
|
||||
|
Loading…
x
Reference in New Issue
Block a user