diff --git a/Gopkg.lock b/Gopkg.lock index 272aaa52..b75bf95a 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -61,6 +61,12 @@ packages = ["acme","acme/autocert"] revision = "e1a4589e7d3ea14a3352255d04b6f1a418845e5e" +[[projects]] + branch = "master" + name = "golang.org/x/net" + packages = ["websocket"] + revision = "e4fa1c5465ad6111f206fc92186b8c83d64adbe1" + [[projects]] branch = "master" name = "golang.org/x/sys" @@ -70,6 +76,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "5f74a2a2ba5b07475ad0faa1b4c021b973ad40b2ae749e3d94e15fe839bb440e" + inputs-digest = "e6ecdde9a4df2afdf849e47fd5a0680122ec991fbadea0a59e54d36b52e8f681" solver-name = "gps-cdcl" solver-version = 1 diff --git a/middleware/jwt.go b/middleware/jwt.go index b2658739..5d2072e7 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -91,7 +91,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { config.Skipper = DefaultJWTConfig.Skipper } if config.SigningKey == nil { - panic("jwt middleware requires signing key") + panic("echo: jwt middleware requires signing key") } if config.SigningMethod == "" { config.SigningMethod = DefaultJWTConfig.SigningMethod diff --git a/middleware/proxy.go b/middleware/proxy.go new file mode 100644 index 00000000..f3718770 --- /dev/null +++ b/middleware/proxy.go @@ -0,0 +1,163 @@ +package middleware + +import ( + "io" + "math/rand" + "net/http" + "net/http/httputil" + "net/url" + "sync/atomic" + "time" + + "github.com/labstack/echo" + "golang.org/x/net/websocket" +) + +type ( + // ProxyConfig defines the config for Proxy middleware. + ProxyConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // Load balancing technique. + // Optional. Default value "random". + // Possible values: + // - "random" + // - "round-robin" + Balance string `json:"balance"` + + // Upstream target URLs + // Required. + Targets []*ProxyTarget `json:"targets"` + + balancer proxyBalancer + } + + // ProxyTarget defines the upstream target. + ProxyTarget struct { + Name string `json:"name,omitempty"` + URL string `json:"url"` + url *url.URL + } + + proxyRandom struct { + targets []*ProxyTarget + random *rand.Rand + } + + proxyRoundRobin struct { + targets []*ProxyTarget + i int32 + } + + proxyBalancer interface { + Next() *ProxyTarget + Length() int + } +) + +func proxyHTTP(u *url.URL, c echo.Context) http.Handler { + return httputil.NewSingleHostReverseProxy(u) +} + +func proxyWS(u *url.URL, c echo.Context) http.Handler { + return websocket.Handler(func(in *websocket.Conn) { + defer in.Close() + + r := in.Request() + t := "ws://" + u.Host + r.RequestURI + out, err := websocket.Dial(t, "", r.Header.Get("Origin")) + if err != nil { + c.Logger().Errorf("ws proxy error, target=%s, err=%v", t, err) + return + } + defer out.Close() + + errc := make(chan error, 2) + cp := func(w io.Writer, r io.Reader) { + _, err := io.Copy(w, r) + errc <- err + } + + go cp(in, out) + go cp(out, in) + err = <-errc + if err != nil && err != io.EOF { + c.Logger().Errorf("ws proxy error, url=%s, err=%v", r.URL, err) + } + }) +} + +func (r *proxyRandom) Next() *ProxyTarget { + return r.targets[r.random.Intn(len(r.targets))] +} + +func (r *proxyRandom) Length() int { + return len(r.targets) +} + +func (r *proxyRoundRobin) Next() *ProxyTarget { + r.i = r.i % int32(len(r.targets)) + atomic.AddInt32(&r.i, 1) + return r.targets[r.i] +} + +func (r *proxyRoundRobin) Length() int { + return len(r.targets) +} + +// Proxy returns an HTTP/WebSocket reverse proxy middleware. +func Proxy(config ProxyConfig) echo.MiddlewareFunc { + // Defaults + if config.Skipper == nil { + config.Skipper = DefaultLoggerConfig.Skipper + } + if config.Targets == nil || len(config.Targets) == 0 { + panic("echo: proxy middleware requires targets") + } + + // Initialize + for _, t := range config.Targets { + u, err := url.Parse(t.URL) + if err != nil { + panic("echo: proxy target url parsing failed" + err.Error()) + } + t.url = u + } + + // Balancer + switch config.Balance { + case "round-robin": + config.balancer = &proxyRoundRobin{ + targets: config.Targets, + i: -1, + } + default: // random + config.balancer = &proxyRandom{ + targets: config.Targets, + random: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))), + } + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) (err error) { + req := c.Request() + res := c.Response() + t := config.balancer.Next().url + + // Tell upstream that the incoming request is HTTPS + if c.IsTLS() { + req.Header.Set(echo.HeaderXForwardedProto, "https") + } + + // Proxy + if req.Header.Get(echo.HeaderUpgrade) == "websocket" { + proxyWS(t, c).ServeHTTP(res, req) + } else { + proxyHTTP(t, c).ServeHTTP(res, req) + } + + return + } + } +} diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go new file mode 100644 index 00000000..65ac5715 --- /dev/null +++ b/middleware/proxy_test.go @@ -0,0 +1,83 @@ +package middleware + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/labstack/echo" + "github.com/stretchr/testify/assert" +) + +type ( + closeNotifyRecorder struct { + *httptest.ResponseRecorder + closed chan bool + } +) + +func newCloseNotifyRecorder() *closeNotifyRecorder { + return &closeNotifyRecorder{ + httptest.NewRecorder(), + make(chan bool, 1), + } +} + +func (c *closeNotifyRecorder) close() { + c.closed <- true +} + +func (c *closeNotifyRecorder) CloseNotify() <-chan bool { + return c.closed +} + +func TestProxy(t *testing.T) { + // Setup + t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "target 1") + })) + defer t1.Close() + t2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "target 2") + })) + defer t2.Close() + config := ProxyConfig{ + Targets: []*ProxyTarget{ + &ProxyTarget{ + URL: t1.URL, + }, + &ProxyTarget{ + URL: t2.URL, + }, + }, + } + + // Random + e := echo.New() + e.Use(Proxy(config)) + req := httptest.NewRequest(echo.GET, "/", nil) + rec := newCloseNotifyRecorder() + e.ServeHTTP(rec, req) + body := rec.Body.String() + targets := map[string]bool{ + "target 1": true, + "target 2": true, + } + assert.Condition(t, func() bool { + return targets[body] + }) + + // Round-robin + config.Balance = "round-robin" + e = echo.New() + e.Use(Proxy(config)) + rec = newCloseNotifyRecorder() + e.ServeHTTP(rec, req) + body = rec.Body.String() + assert.Equal(t, "target 1", body) + rec = newCloseNotifyRecorder() + e.ServeHTTP(rec, req) + body = rec.Body.String() + assert.Equal(t, "target 2", body) +} diff --git a/router.go b/router.go index 33da20ed..2ef904e0 100644 --- a/router.go +++ b/router.go @@ -101,7 +101,7 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string cn := r.tree // Current node as root if cn == nil { - panic("echo ⇛ invalid method") + panic("echo: invalid method") } search := path