From 282b4b268cc6581a741a5dbbdf1c3d4d253f66fe Mon Sep 17 00:00:00 2001 From: Umputun Date: Fri, 28 May 2021 16:11:16 -0500 Subject: [PATCH] add lb selector --- app/main.go | 15 +++++++++++++++ app/proxy/proxy.go | 18 +++++++++++------- app/proxy/proxy_test.go | 4 ++-- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/app/main.go b/app/main.go index 57d44b0..c680945 100644 --- a/app/main.go +++ b/app/main.go @@ -7,6 +7,7 @@ import ( "io" "io/ioutil" "math" + "math/rand" "net/http" "os" "os/signal" @@ -31,6 +32,7 @@ var opts struct { MaxSize string `short:"m" long:"max" env:"MAX_SIZE" default:"64K" description:"max request size"` GzipEnabled bool `short:"g" long:"gzip" env:"GZIP" description:"enable gz compression"` ProxyHeaders []string `short:"x" long:"header" env:"HEADER" description:"proxy headers" env-delim:","` + LBType string `long:"lb-type" env:"LB_TYPE" description:"load balancer type" choice:"random" choice:"failover" default:"random"` //nolint SSL struct { Type string `long:"type" env:"TYPE" description:"ssl (auto) support" choice:"none" choice:"static" choice:"auto" default:"none"` //nolint @@ -239,6 +241,7 @@ func run() error { AccessLog: accessLog, StdOutEnabled: opts.Logger.StdOut, Signature: opts.Signature, + LBSelector: makeLBSelector(), Timeouts: proxy.Timeouts{ ReadHeader: opts.Timeouts.ReadHeader, Write: opts.Timeouts.Write, @@ -308,6 +311,18 @@ func makeProviders() ([]discovery.Provider, error) { return res, nil } +func makeLBSelector() func(len int) int { + switch opts.LBType { + case "random": + rand.Seed(time.Now().UnixNano()) + return rand.Intn + case "failover": + return func(int) int { return 0 } // dead server won't be in the list, we can safely pick the first one + default: + return func(int) int { return 0 } + } +} + func makeSSLConfig() (config proxy.SSLConfig, err error) { switch opts.SSL.Type { case "none": diff --git a/app/proxy/proxy.go b/app/proxy/proxy.go index a42bec5..5300155 100644 --- a/app/proxy/proxy.go +++ b/app/proxy/proxy.go @@ -41,6 +41,7 @@ type Http struct { // nolint golint CacheControl MiddlewareProvider Metrics MiddlewareProvider Reporter Reporter + LBSelector func(len int) int } // Matcher source info (server and route) to the destination url @@ -84,6 +85,10 @@ func (h *Http) Run(ctx context.Context) error { log.Printf("[DEBUG] assets file server enabled for %s, webroot %s", h.AssetsLocation, h.AssetsWebRoot) } + if h.LBSelector == nil { + h.LBSelector = rand.Intn + } + var httpServer, httpsServer *http.Server go func() { @@ -113,8 +118,6 @@ func (h *Http) Run(ctx context.Context) error { h.gzipHandler(), ) - rand.Seed(time.Now().UnixNano()) - if len(h.SSLConfig.FQDNs) == 0 && h.SSLConfig.SSLMode == SSLAuto { // discovery async and may happen not right away. Try to get servers for some time for i := 0; i < 100; i++ { @@ -206,8 +209,8 @@ func (h *Http) proxyHandler() http.HandlerFunc { server = strings.Split(r.Host, ":")[0] } matches := h.Match(server, r.URL.Path) // get all matches for the server:path pair - u, ok := h.getMatch(matches, rand.Intn) - if !ok { // no route match + u, ok := h.getMatch(matches) // pick a single match from alive only, uses LBSelector as the strategy + if !ok { // no route match if h.isAssetRequest(r) { assetsHandler.ServeHTTP(w, r) return @@ -244,24 +247,25 @@ func (h *Http) proxyHandler() http.HandlerFunc { } } -func (h *Http) getMatch(mm discovery.Matches, picker func(len int) int) (u string, ok bool) { +func (h *Http) getMatch(mm discovery.Matches) (u string, ok bool) { if len(mm.Routes) == 0 { return "", false } - var urls []string + var urls []string // alive destinations only for _, m := range mm.Routes { if m.Alive { urls = append(urls, m.Destination) } } + switch len(urls) { case 0: return "", false case 1: return urls[0], true default: - return urls[picker(len(urls))], true + return urls[h.LBSelector(len(urls))], true } } diff --git a/app/proxy/proxy_test.go b/app/proxy/proxy_test.go index 4f9b50d..cfe5da3 100644 --- a/app/proxy/proxy_test.go +++ b/app/proxy/proxy_test.go @@ -408,10 +408,10 @@ func TestHttp_getMatch(t *testing.T) { }, } - h := Http{} + h := Http{LBSelector: func(len int) int { return 0 }} for i, tt := range tbl { t.Run(strconv.Itoa(i), func(t *testing.T) { - res, ok := h.getMatch(tt.matches, func(len int) int { return 0 }) + res, ok := h.getMatch(tt.matches) require.Equal(t, tt.ok, ok) assert.Equal(t, tt.res, res) })