mirror of
https://github.com/labstack/echo.git
synced 2025-01-12 01:22:21 +02:00
Exposed proxy balancer interface
Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
parent
a8cd0ad133
commit
0898d9e68b
1
echo.go
1
echo.go
@ -165,6 +165,7 @@ const (
|
||||
|
||||
// Headers
|
||||
const (
|
||||
HeaderAccept = "Accept"
|
||||
HeaderAcceptEncoding = "Accept-Encoding"
|
||||
HeaderAllow = "Allow"
|
||||
HeaderAuthorization = "Authorization"
|
||||
|
@ -1,8 +1,11 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
@ -10,7 +13,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo"
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
type (
|
||||
@ -19,91 +21,95 @@ type (
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
|
||||
// Load balancing technique.
|
||||
// Optional. Default value "random".
|
||||
// Possible values:
|
||||
// - "random"
|
||||
// - "round-robin"
|
||||
Balance string `json:"balance"`
|
||||
|
||||
// Upstream target URLs
|
||||
// Balance defines a load balancing technique.
|
||||
// Required.
|
||||
Targets []*ProxyTarget `json:"targets"`
|
||||
|
||||
balancer proxyBalancer
|
||||
// Possible values:
|
||||
// - ProxyRandom
|
||||
// - ProxyRoundRobin
|
||||
Balancer ProxyBalancer
|
||||
}
|
||||
|
||||
// ProxyTarget defines the upstream target.
|
||||
ProxyTarget struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
URL string `json:"url"`
|
||||
url *url.URL
|
||||
URL *url.URL
|
||||
}
|
||||
|
||||
proxyRandom struct {
|
||||
targets []*ProxyTarget
|
||||
RandomBalancer struct {
|
||||
Targets []*ProxyTarget
|
||||
random *rand.Rand
|
||||
}
|
||||
|
||||
proxyRoundRobin struct {
|
||||
targets []*ProxyTarget
|
||||
i int32
|
||||
RoundRobinBalancer struct {
|
||||
Targets []*ProxyTarget
|
||||
i uint32
|
||||
}
|
||||
|
||||
proxyBalancer interface {
|
||||
ProxyBalancer interface {
|
||||
Next() *ProxyTarget
|
||||
Length() int
|
||||
}
|
||||
)
|
||||
|
||||
func proxyHTTP(u *url.URL, c echo.Context) http.Handler {
|
||||
return httputil.NewSingleHostReverseProxy(u)
|
||||
func proxyHTTP(t *ProxyTarget) http.Handler {
|
||||
return httputil.NewSingleHostReverseProxy(t.URL)
|
||||
}
|
||||
|
||||
func proxyWS(u *url.URL, c echo.Context) http.Handler {
|
||||
return websocket.Handler(func(in *websocket.Conn) {
|
||||
func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
h, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
c.Error(errors.New("proxy raw, not a hijacker"))
|
||||
return
|
||||
}
|
||||
|
||||
in, _, err := h.Hijack()
|
||||
if err != nil {
|
||||
c.Error(fmt.Errorf("proxy raw hijack error=%v, url=%s", r.URL, err))
|
||||
return
|
||||
}
|
||||
defer in.Close()
|
||||
|
||||
r := in.Request()
|
||||
t := "ws://" + u.Host + r.RequestURI
|
||||
out, err := websocket.Dial(t, "", r.Header.Get("Origin"))
|
||||
out, err := net.Dial("tcp", t.URL.Host)
|
||||
if err != nil {
|
||||
c.Logger().Errorf("ws proxy error, target=%s, err=%v", t, err)
|
||||
he := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw dial error=%v, url=%s", r.URL, err))
|
||||
c.Error(he)
|
||||
return
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
err = r.Write(out)
|
||||
if err != nil {
|
||||
he := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw request copy error=%v, url=%s", r.URL, err))
|
||||
c.Error(he)
|
||||
return
|
||||
}
|
||||
|
||||
errc := make(chan error, 2)
|
||||
cp := func(w io.Writer, r io.Reader) {
|
||||
_, err := io.Copy(w, r)
|
||||
cp := func(dst io.Writer, src io.Reader) {
|
||||
_, err := io.Copy(dst, src)
|
||||
errc <- err
|
||||
}
|
||||
|
||||
go cp(in, out)
|
||||
go cp(out, in)
|
||||
go cp(in, out)
|
||||
err = <-errc
|
||||
if err != nil && err != io.EOF {
|
||||
c.Logger().Errorf("ws proxy error, url=%s, err=%v", r.URL, err)
|
||||
c.Logger().Errorf("proxy raw error=%v, url=%s", r.URL, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (r *proxyRandom) Next() *ProxyTarget {
|
||||
return r.targets[r.random.Intn(len(r.targets))]
|
||||
func (r *RandomBalancer) Next() *ProxyTarget {
|
||||
if r.random == nil {
|
||||
r.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
|
||||
}
|
||||
return r.Targets[r.random.Intn(len(r.Targets))]
|
||||
}
|
||||
|
||||
func (r *proxyRandom) Length() int {
|
||||
return len(r.targets)
|
||||
}
|
||||
|
||||
func (r *proxyRoundRobin) Next() *ProxyTarget {
|
||||
r.i = r.i % int32(len(r.targets))
|
||||
atomic.AddInt32(&r.i, 1)
|
||||
return r.targets[r.i]
|
||||
}
|
||||
|
||||
func (r *proxyRoundRobin) Length() int {
|
||||
return len(r.targets)
|
||||
func (r *RoundRobinBalancer) Next() *ProxyTarget {
|
||||
r.i = r.i % uint32(len(r.Targets))
|
||||
t := r.Targets[r.i]
|
||||
atomic.AddUint32(&r.i, 1)
|
||||
return t
|
||||
}
|
||||
|
||||
// Proxy returns an HTTP/WebSocket reverse proxy middleware.
|
||||
@ -112,49 +118,26 @@ func Proxy(config ProxyConfig) echo.MiddlewareFunc {
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultLoggerConfig.Skipper
|
||||
}
|
||||
if config.Targets == nil || len(config.Targets) == 0 {
|
||||
panic("echo: proxy middleware requires targets")
|
||||
}
|
||||
|
||||
// Initialize
|
||||
for _, t := range config.Targets {
|
||||
u, err := url.Parse(t.URL)
|
||||
if err != nil {
|
||||
panic("echo: proxy target url parsing failed" + err.Error())
|
||||
}
|
||||
t.url = u
|
||||
}
|
||||
|
||||
// Balancer
|
||||
switch config.Balance {
|
||||
case "round-robin":
|
||||
config.balancer = &proxyRoundRobin{
|
||||
targets: config.Targets,
|
||||
i: -1,
|
||||
}
|
||||
default: // random
|
||||
config.balancer = &proxyRandom{
|
||||
targets: config.Targets,
|
||||
random: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))),
|
||||
}
|
||||
if config.Balancer == nil {
|
||||
panic("echo: proxy middleware requires balancer")
|
||||
}
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) (err error) {
|
||||
req := c.Request()
|
||||
res := c.Response()
|
||||
t := config.balancer.Next().url
|
||||
|
||||
// Tell upstream that the incoming request is HTTPS
|
||||
if c.IsTLS() {
|
||||
req.Header.Set(echo.HeaderXForwardedProto, "https")
|
||||
}
|
||||
t := config.Balancer.Next()
|
||||
|
||||
// Proxy
|
||||
if req.Header.Get(echo.HeaderUpgrade) == "websocket" {
|
||||
proxyWS(t, c).ServeHTTP(res, req)
|
||||
} else {
|
||||
proxyHTTP(t, c).ServeHTTP(res, req)
|
||||
upgrade := req.Header.Get(echo.HeaderUpgrade)
|
||||
accept := req.Header.Get(echo.HeaderAccept)
|
||||
|
||||
switch {
|
||||
case upgrade == "websocket" || upgrade == "Websocket":
|
||||
proxyRaw(t, c).ServeHTTP(res, req)
|
||||
case accept == "text/event-stream":
|
||||
default:
|
||||
proxyHTTP(t).ServeHTTP(res, req)
|
||||
}
|
||||
|
||||
return
|
||||
|
@ -6,6 +6,8 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"net/url"
|
||||
|
||||
"github.com/labstack/echo"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
@ -38,18 +40,24 @@ func TestProxy(t *testing.T) {
|
||||
fmt.Fprint(w, "target 1")
|
||||
}))
|
||||
defer t1.Close()
|
||||
url1, _ := url.Parse(t1.URL)
|
||||
t2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, "target 2")
|
||||
}))
|
||||
defer t2.Close()
|
||||
url2, _ := url.Parse(t2.URL)
|
||||
|
||||
targets := []*ProxyTarget{
|
||||
&ProxyTarget{
|
||||
URL: url1,
|
||||
},
|
||||
&ProxyTarget{
|
||||
URL: url2,
|
||||
},
|
||||
}
|
||||
config := ProxyConfig{
|
||||
Targets: []*ProxyTarget{
|
||||
&ProxyTarget{
|
||||
URL: t1.URL,
|
||||
},
|
||||
&ProxyTarget{
|
||||
URL: t2.URL,
|
||||
},
|
||||
Balancer: &RandomBalancer{
|
||||
Targets: targets,
|
||||
},
|
||||
}
|
||||
|
||||
@ -60,16 +68,18 @@ func TestProxy(t *testing.T) {
|
||||
rec := newCloseNotifyRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
body := rec.Body.String()
|
||||
targets := map[string]bool{
|
||||
expected := map[string]bool{
|
||||
"target 1": true,
|
||||
"target 2": true,
|
||||
}
|
||||
assert.Condition(t, func() bool {
|
||||
return targets[body]
|
||||
return expected[body]
|
||||
})
|
||||
|
||||
// Round-robin
|
||||
config.Balance = "round-robin"
|
||||
config.Balancer = &RoundRobinBalancer{
|
||||
Targets: targets,
|
||||
}
|
||||
e = echo.New()
|
||||
e.Use(Proxy(config))
|
||||
rec = newCloseNotifyRecorder()
|
||||
|
Loading…
Reference in New Issue
Block a user