mirror of
https://github.com/labstack/echo.git
synced 2025-11-29 22:48:07 +02:00
Add support for TLS WebSocket proxy (#2762)
* Add support for TLS WebSocket proxy * support tls to non-tls and non-tls to tls websocket proxy
This commit is contained in:
@@ -5,6 +5,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
@@ -130,7 +131,21 @@ var DefaultProxyConfig = ProxyConfig{
|
||||
ContextKey: "target",
|
||||
}
|
||||
|
||||
func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
|
||||
func proxyRaw(t *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler {
|
||||
var dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
if transport, ok := config.Transport.(*http.Transport); ok {
|
||||
if transport.TLSClientConfig != nil {
|
||||
d := tls.Dialer{
|
||||
Config: transport.TLSClientConfig,
|
||||
}
|
||||
dialFunc = d.DialContext
|
||||
}
|
||||
}
|
||||
if dialFunc == nil {
|
||||
var d net.Dialer
|
||||
dialFunc = d.DialContext
|
||||
}
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
in, _, err := c.Response().Hijack()
|
||||
if err != nil {
|
||||
@@ -138,13 +153,11 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
|
||||
return
|
||||
}
|
||||
defer in.Close()
|
||||
|
||||
out, err := net.Dial("tcp", t.URL.Host)
|
||||
out, err := dialFunc(c.Request().Context(), "tcp", t.URL.Host)
|
||||
if err != nil {
|
||||
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL)))
|
||||
return
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
// Write header
|
||||
err = r.Write(out)
|
||||
@@ -365,7 +378,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
|
||||
// Proxy
|
||||
switch {
|
||||
case c.IsWebSocket():
|
||||
proxyRaw(tgt, c).ServeHTTP(res, req)
|
||||
proxyRaw(tgt, c, config).ServeHTTP(res, req)
|
||||
default: // even SSE requests
|
||||
proxyHTTP(tgt, c, config).ServeHTTP(res, req)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ package middleware
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -20,6 +21,7 @@ import (
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
// Assert expected with url.EscapedPath method to obtain the path.
|
||||
@@ -810,3 +812,231 @@ func TestModifyResponseUseContext(t *testing.T) {
|
||||
assert.Equal(t, "OK", rec.Body.String())
|
||||
assert.Equal(t, "CUSTOM_BALANCER", rec.Header().Get("FROM_BALANCER"))
|
||||
}
|
||||
|
||||
func createSimpleWebSocketServer(serveTLS bool) *httptest.Server {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
wsHandler := func(conn *websocket.Conn) {
|
||||
defer conn.Close()
|
||||
for {
|
||||
var msg string
|
||||
err := websocket.Message.Receive(conn, &msg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// message back to the client
|
||||
websocket.Message.Send(conn, msg)
|
||||
}
|
||||
}
|
||||
websocket.Server{Handler: wsHandler}.ServeHTTP(w, r)
|
||||
})
|
||||
if serveTLS {
|
||||
return httptest.NewTLSServer(handler)
|
||||
}
|
||||
return httptest.NewServer(handler)
|
||||
}
|
||||
|
||||
func createSimpleProxyServer(t *testing.T, srv *httptest.Server, serveTLS bool, toTLS bool) *httptest.Server {
|
||||
e := echo.New()
|
||||
|
||||
if toTLS {
|
||||
// proxy to tls target
|
||||
tgtURL, _ := url.Parse(srv.URL)
|
||||
tgtURL.Scheme = "wss"
|
||||
balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})
|
||||
|
||||
defaultTransport, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatal("Default transport is not of type *http.Transport")
|
||||
}
|
||||
transport := defaultTransport.Clone()
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer, Transport: transport}))
|
||||
} else {
|
||||
// proxy to non-TLS target
|
||||
tgtURL, _ := url.Parse(srv.URL)
|
||||
balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})
|
||||
e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer}))
|
||||
}
|
||||
|
||||
if serveTLS {
|
||||
// serve proxy server with TLS
|
||||
ts := httptest.NewTLSServer(e)
|
||||
return ts
|
||||
}
|
||||
// serve proxy server without TLS
|
||||
ts := httptest.NewServer(e)
|
||||
return ts
|
||||
}
|
||||
|
||||
// TestProxyWithConfigWebSocketNonTLS2NonTLS tests the proxy with non-TLS to non-TLS WebSocket connection.
|
||||
func TestProxyWithConfigWebSocketNonTLS2NonTLS(t *testing.T) {
|
||||
/*
|
||||
Arrange
|
||||
*/
|
||||
// Create a WebSocket test server (non-TLS)
|
||||
srv := createSimpleWebSocketServer(false)
|
||||
defer srv.Close()
|
||||
|
||||
// create proxy server (non-TLS to non-TLS)
|
||||
ts := createSimpleProxyServer(t, srv, false, false)
|
||||
defer ts.Close()
|
||||
|
||||
tsURL, _ := url.Parse(ts.URL)
|
||||
tsURL.Scheme = "ws"
|
||||
tsURL.Path = "/"
|
||||
|
||||
/*
|
||||
Act
|
||||
*/
|
||||
|
||||
// Connect to the proxy WebSocket
|
||||
wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/")
|
||||
assert.NoError(t, err)
|
||||
defer wsConn.Close()
|
||||
|
||||
// Send message
|
||||
sendMsg := "Hello, Non TLS WebSocket!"
|
||||
err = websocket.Message.Send(wsConn, sendMsg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
/*
|
||||
Assert
|
||||
*/
|
||||
// Read response
|
||||
var recvMsg string
|
||||
err = websocket.Message.Receive(wsConn, &recvMsg)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, sendMsg, recvMsg)
|
||||
}
|
||||
|
||||
// TestProxyWithConfigWebSocketTLS2TLS tests the proxy with TLS to TLS WebSocket connection.
|
||||
func TestProxyWithConfigWebSocketTLS2TLS(t *testing.T) {
|
||||
/*
|
||||
Arrange
|
||||
*/
|
||||
// Create a WebSocket test server (TLS)
|
||||
srv := createSimpleWebSocketServer(true)
|
||||
defer srv.Close()
|
||||
|
||||
// create proxy server (TLS to TLS)
|
||||
ts := createSimpleProxyServer(t, srv, true, true)
|
||||
defer ts.Close()
|
||||
|
||||
tsURL, _ := url.Parse(ts.URL)
|
||||
tsURL.Scheme = "wss"
|
||||
tsURL.Path = "/"
|
||||
|
||||
/*
|
||||
Act
|
||||
*/
|
||||
origin, err := url.Parse(ts.URL)
|
||||
assert.NoError(t, err)
|
||||
config := &websocket.Config{
|
||||
Location: tsURL,
|
||||
Origin: origin,
|
||||
TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing
|
||||
Version: websocket.ProtocolVersionHybi13,
|
||||
}
|
||||
wsConn, err := websocket.DialConfig(config)
|
||||
assert.NoError(t, err)
|
||||
defer wsConn.Close()
|
||||
|
||||
// Send message
|
||||
sendMsg := "Hello, TLS to TLS WebSocket!"
|
||||
err = websocket.Message.Send(wsConn, sendMsg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Read response
|
||||
var recvMsg string
|
||||
err = websocket.Message.Receive(wsConn, &recvMsg)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, sendMsg, recvMsg)
|
||||
}
|
||||
|
||||
// TestProxyWithConfigWebSocketNonTLS2TLS tests the proxy with non-TLS to TLS WebSocket connection.
|
||||
func TestProxyWithConfigWebSocketNonTLS2TLS(t *testing.T) {
|
||||
/*
|
||||
Arrange
|
||||
*/
|
||||
|
||||
// Create a WebSocket test server (TLS)
|
||||
srv := createSimpleWebSocketServer(true)
|
||||
defer srv.Close()
|
||||
|
||||
// create proxy server (Non-TLS to TLS)
|
||||
ts := createSimpleProxyServer(t, srv, false, true)
|
||||
defer ts.Close()
|
||||
|
||||
tsURL, _ := url.Parse(ts.URL)
|
||||
tsURL.Scheme = "ws"
|
||||
tsURL.Path = "/"
|
||||
|
||||
/*
|
||||
Act
|
||||
*/
|
||||
// Connect to the proxy WebSocket
|
||||
wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/")
|
||||
assert.NoError(t, err)
|
||||
defer wsConn.Close()
|
||||
|
||||
// Send message
|
||||
sendMsg := "Hello, Non TLS to TLS WebSocket!"
|
||||
err = websocket.Message.Send(wsConn, sendMsg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
/*
|
||||
Assert
|
||||
*/
|
||||
// Read response
|
||||
var recvMsg string
|
||||
err = websocket.Message.Receive(wsConn, &recvMsg)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, sendMsg, recvMsg)
|
||||
}
|
||||
|
||||
// TestProxyWithConfigWebSocketTLSToNoneTLS tests the proxy with TLS to non-TLS WebSocket connection. (TLS termination)
|
||||
func TestProxyWithConfigWebSocketTLS2NonTLS(t *testing.T) {
|
||||
/*
|
||||
Arrange
|
||||
*/
|
||||
|
||||
// Create a WebSocket test server (non-TLS)
|
||||
srv := createSimpleWebSocketServer(false)
|
||||
defer srv.Close()
|
||||
|
||||
// create proxy server (TLS to non-TLS)
|
||||
ts := createSimpleProxyServer(t, srv, true, false)
|
||||
defer ts.Close()
|
||||
|
||||
tsURL, _ := url.Parse(ts.URL)
|
||||
tsURL.Scheme = "wss"
|
||||
tsURL.Path = "/"
|
||||
|
||||
/*
|
||||
Act
|
||||
*/
|
||||
origin, err := url.Parse(ts.URL)
|
||||
assert.NoError(t, err)
|
||||
config := &websocket.Config{
|
||||
Location: tsURL,
|
||||
Origin: origin,
|
||||
TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing
|
||||
Version: websocket.ProtocolVersionHybi13,
|
||||
}
|
||||
wsConn, err := websocket.DialConfig(config)
|
||||
assert.NoError(t, err)
|
||||
defer wsConn.Close()
|
||||
|
||||
// Send message
|
||||
sendMsg := "Hello, TLS to NoneTLS WebSocket!"
|
||||
err = websocket.Message.Send(wsConn, sendMsg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Read response
|
||||
var recvMsg string
|
||||
err = websocket.Message.Receive(wsConn, &recvMsg)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, sendMsg, recvMsg)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user