diff --git a/context.go b/context.go index 44822f2f..242eec26 100644 --- a/context.go +++ b/context.go @@ -31,6 +31,9 @@ type ( // IsTLS returns true if HTTP connection is TLS otherwise false. IsTLS() bool + // IsWebSocket returns true if HTTP connection is WebSocket otherwise false. + IsWebSocket() bool + // Scheme returns the HTTP protocol scheme, `http` or `https`. Scheme() string @@ -219,6 +222,11 @@ func (c *context) IsTLS() bool { return c.request.TLS != nil } +func (c *context) IsWebSocket() bool { + upgrade := c.request.Header.Get(HeaderUpgrade) + return upgrade == "websocket" || upgrade == "Websocket" +} + func (c *context) Scheme() string { // Can't use `r.Request.URL.Scheme` // See: https://groups.google.com/forum/#!topic/golang-nuts/pMUkBlQBDF0 @@ -227,10 +235,16 @@ func (c *context) Scheme() string { } if scheme := c.request.Header.Get(HeaderXForwardedProto); scheme != "" { return scheme - } + } + if scheme := c.request.Header.Get(HeaderXForwardedProtocol); scheme != "" { + return scheme + } if ssl := c.request.Header.Get(HeaderXForwardedSsl); ssl == "on" { return "https" } + if scheme := c.request.Header.Get(HeaderXUrlScheme); scheme != "" { + return scheme + } return "http" } diff --git a/echo.go b/echo.go index 60733519..baa92fde 100644 --- a/echo.go +++ b/echo.go @@ -165,30 +165,34 @@ const ( // Headers const ( - HeaderAccept = "Accept" - HeaderAcceptEncoding = "Accept-Encoding" - HeaderAllow = "Allow" - HeaderAuthorization = "Authorization" - HeaderContentDisposition = "Content-Disposition" - HeaderContentEncoding = "Content-Encoding" - HeaderContentLength = "Content-Length" - HeaderContentType = "Content-Type" - HeaderCookie = "Cookie" - HeaderSetCookie = "Set-Cookie" - HeaderIfModifiedSince = "If-Modified-Since" - HeaderLastModified = "Last-Modified" - HeaderLocation = "Location" - HeaderUpgrade = "Upgrade" - HeaderVary = "Vary" - HeaderWWWAuthenticate = "WWW-Authenticate" - HeaderXForwardedProto = "X-Forwarded-Proto" - HeaderXHTTPMethodOverride = "X-HTTP-Method-Override" - HeaderXForwardedFor = "X-Forwarded-For" - HeaderXForwardedSsl = "X-Forwarded-Ssl" - HeaderXRealIP = "X-Real-IP" - HeaderXRequestID = "X-Request-ID" - HeaderServer = "Server" - HeaderOrigin = "Origin" + HeaderAccept = "Accept" + HeaderAcceptEncoding = "Accept-Encoding" + HeaderAllow = "Allow" + HeaderAuthorization = "Authorization" + HeaderContentDisposition = "Content-Disposition" + HeaderContentEncoding = "Content-Encoding" + HeaderContentLength = "Content-Length" + HeaderContentType = "Content-Type" + HeaderCookie = "Cookie" + HeaderSetCookie = "Set-Cookie" + HeaderIfModifiedSince = "If-Modified-Since" + HeaderLastModified = "Last-Modified" + HeaderLocation = "Location" + HeaderUpgrade = "Upgrade" + HeaderVary = "Vary" + HeaderWWWAuthenticate = "WWW-Authenticate" + HeaderXForwardedFor = "X-Forwarded-For" + HeaderXForwardedProto = "X-Forwarded-Proto" + HeaderXForwardedProtocol = "X-Forwarded-Protocol" + HeaderXForwardedSsl = "X-Forwarded-Ssl" + HeaderXUrlScheme = "X-Url-Scheme" + HeaderXHTTPMethodOverride = "X-HTTP-Method-Override" + HeaderXRealIP = "X-Real-IP" + HeaderXRequestID = "X-Request-ID" + HeaderServer = "Server" + HeaderOrigin = "Origin" + + // Access control HeaderAccessControlRequestMethod = "Access-Control-Request-Method" HeaderAccessControlRequestHeaders = "Access-Control-Request-Headers" HeaderAccessControlAllowOrigin = "Access-Control-Allow-Origin" diff --git a/middleware/proxy.go b/middleware/proxy.go index 9277e30c..7eb24abf 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -15,6 +15,8 @@ import ( "github.com/labstack/echo" ) +// TODO: Handle TLS proxy + type ( // ProxyConfig defines the config for Proxy middleware. ProxyConfig struct { @@ -63,17 +65,16 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { c.Error(errors.New("proxy raw, not a hijacker")) return } - in, _, err := h.Hijack() if err != nil { - c.Error(fmt.Errorf("proxy raw hijack error=%v, url=%s", r.URL, err)) + c.Error(fmt.Errorf("proxy raw, hijack error=%v, url=%s", r.URL, err)) return } defer in.Close() out, err := net.Dial("tcp", t.URL.Host) if err != nil { - he := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw dial error=%v, url=%s", r.URL, err)) + he := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", r.URL, err)) c.Error(he) return } @@ -81,7 +82,7 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { err = r.Write(out) if err != nil { - he := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw request copy error=%v, url=%s", r.URL, err)) + he := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request copy error=%v, url=%s", r.URL, err)) c.Error(he) return } @@ -96,7 +97,7 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { go cp(in, out) err = <-errc if err != nil && err != io.EOF { - c.Logger().Errorf("proxy raw error=%v, url=%s", r.URL, err) + c.Logger().Errorf("proxy raw, error=%v, url=%s", r.URL, err) } }) } @@ -131,18 +132,26 @@ func Proxy(config ProxyConfig) echo.MiddlewareFunc { return func(c echo.Context) (err error) { req := c.Request() res := c.Response() - t := config.Balancer.Next() + tgt := config.Balancer.Next() + + // Fix header + if req.Header.Get(echo.HeaderXRealIP) == "" { + req.Header.Set(echo.HeaderXRealIP, c.RealIP()) + } + if req.Header.Get(echo.HeaderXForwardedProto) == "" { + req.Header.Set(echo.HeaderXForwardedProto, c.Scheme()) + } + if c.IsWebSocket() && req.Header.Get(echo.HeaderXForwardedFor) == "" { // For HTTP, it is automatically set by Go HTTP reverse proxy. + req.Header.Set(echo.HeaderXForwardedFor, c.RealIP()) + } // Proxy - upgrade := req.Header.Get(echo.HeaderUpgrade) - accept := req.Header.Get(echo.HeaderAccept) - switch { - case upgrade == "websocket" || upgrade == "Websocket": - proxyRaw(t, c).ServeHTTP(res, req) - case accept == "text/event-stream": + case c.IsWebSocket(): + proxyRaw(tgt, c).ServeHTTP(res, req) + case req.Header.Get(echo.HeaderAccept) == "text/event-stream": default: - proxyHTTP(t).ServeHTTP(res, req) + proxyHTTP(tgt).ServeHTTP(res, req) } return