1
0
mirror of https://github.com/labstack/echo.git synced 2025-11-29 22:48:07 +02:00

Add support for TLS WebSocket proxy (#2762)

* Add support for TLS WebSocket proxy

* support tls to non-tls and non-tls to tls websocket proxy
This commit is contained in:
t-ibayashi-safie
2025-04-04 17:01:42 +09:00
committed by GitHub
parent c44f6283f0
commit de44c53a5b
2 changed files with 248 additions and 5 deletions

View File

@@ -5,6 +5,7 @@ package middleware
import (
"context"
"crypto/tls"
"fmt"
"io"
"math/rand"
@@ -130,7 +131,21 @@ var DefaultProxyConfig = ProxyConfig{
ContextKey: "target",
}
func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
func proxyRaw(t *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler {
var dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
if transport, ok := config.Transport.(*http.Transport); ok {
if transport.TLSClientConfig != nil {
d := tls.Dialer{
Config: transport.TLSClientConfig,
}
dialFunc = d.DialContext
}
}
if dialFunc == nil {
var d net.Dialer
dialFunc = d.DialContext
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
in, _, err := c.Response().Hijack()
if err != nil {
@@ -138,13 +153,11 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
return
}
defer in.Close()
out, err := net.Dial("tcp", t.URL.Host)
out, err := dialFunc(c.Request().Context(), "tcp", t.URL.Host)
if err != nil {
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL)))
return
}
defer out.Close()
// Write header
err = r.Write(out)
@@ -365,7 +378,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
// Proxy
switch {
case c.IsWebSocket():
proxyRaw(tgt, c).ServeHTTP(res, req)
proxyRaw(tgt, c, config).ServeHTTP(res, req)
default: // even SSE requests
proxyHTTP(tgt, c, config).ServeHTTP(res, req)
}

View File

@@ -6,6 +6,7 @@ package middleware
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
@@ -20,6 +21,7 @@ import (
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"golang.org/x/net/websocket"
)
// Assert expected with url.EscapedPath method to obtain the path.
@@ -810,3 +812,231 @@ func TestModifyResponseUseContext(t *testing.T) {
assert.Equal(t, "OK", rec.Body.String())
assert.Equal(t, "CUSTOM_BALANCER", rec.Header().Get("FROM_BALANCER"))
}
func createSimpleWebSocketServer(serveTLS bool) *httptest.Server {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wsHandler := func(conn *websocket.Conn) {
defer conn.Close()
for {
var msg string
err := websocket.Message.Receive(conn, &msg)
if err != nil {
return
}
// message back to the client
websocket.Message.Send(conn, msg)
}
}
websocket.Server{Handler: wsHandler}.ServeHTTP(w, r)
})
if serveTLS {
return httptest.NewTLSServer(handler)
}
return httptest.NewServer(handler)
}
func createSimpleProxyServer(t *testing.T, srv *httptest.Server, serveTLS bool, toTLS bool) *httptest.Server {
e := echo.New()
if toTLS {
// proxy to tls target
tgtURL, _ := url.Parse(srv.URL)
tgtURL.Scheme = "wss"
balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})
defaultTransport, ok := http.DefaultTransport.(*http.Transport)
if !ok {
t.Fatal("Default transport is not of type *http.Transport")
}
transport := defaultTransport.Clone()
transport.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
}
e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer, Transport: transport}))
} else {
// proxy to non-TLS target
tgtURL, _ := url.Parse(srv.URL)
balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})
e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer}))
}
if serveTLS {
// serve proxy server with TLS
ts := httptest.NewTLSServer(e)
return ts
}
// serve proxy server without TLS
ts := httptest.NewServer(e)
return ts
}
// TestProxyWithConfigWebSocketNonTLS2NonTLS tests the proxy with non-TLS to non-TLS WebSocket connection.
func TestProxyWithConfigWebSocketNonTLS2NonTLS(t *testing.T) {
/*
Arrange
*/
// Create a WebSocket test server (non-TLS)
srv := createSimpleWebSocketServer(false)
defer srv.Close()
// create proxy server (non-TLS to non-TLS)
ts := createSimpleProxyServer(t, srv, false, false)
defer ts.Close()
tsURL, _ := url.Parse(ts.URL)
tsURL.Scheme = "ws"
tsURL.Path = "/"
/*
Act
*/
// Connect to the proxy WebSocket
wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/")
assert.NoError(t, err)
defer wsConn.Close()
// Send message
sendMsg := "Hello, Non TLS WebSocket!"
err = websocket.Message.Send(wsConn, sendMsg)
assert.NoError(t, err)
/*
Assert
*/
// Read response
var recvMsg string
err = websocket.Message.Receive(wsConn, &recvMsg)
assert.NoError(t, err)
assert.Equal(t, sendMsg, recvMsg)
}
// TestProxyWithConfigWebSocketTLS2TLS tests the proxy with TLS to TLS WebSocket connection.
func TestProxyWithConfigWebSocketTLS2TLS(t *testing.T) {
/*
Arrange
*/
// Create a WebSocket test server (TLS)
srv := createSimpleWebSocketServer(true)
defer srv.Close()
// create proxy server (TLS to TLS)
ts := createSimpleProxyServer(t, srv, true, true)
defer ts.Close()
tsURL, _ := url.Parse(ts.URL)
tsURL.Scheme = "wss"
tsURL.Path = "/"
/*
Act
*/
origin, err := url.Parse(ts.URL)
assert.NoError(t, err)
config := &websocket.Config{
Location: tsURL,
Origin: origin,
TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing
Version: websocket.ProtocolVersionHybi13,
}
wsConn, err := websocket.DialConfig(config)
assert.NoError(t, err)
defer wsConn.Close()
// Send message
sendMsg := "Hello, TLS to TLS WebSocket!"
err = websocket.Message.Send(wsConn, sendMsg)
assert.NoError(t, err)
// Read response
var recvMsg string
err = websocket.Message.Receive(wsConn, &recvMsg)
assert.NoError(t, err)
assert.Equal(t, sendMsg, recvMsg)
}
// TestProxyWithConfigWebSocketNonTLS2TLS tests the proxy with non-TLS to TLS WebSocket connection.
func TestProxyWithConfigWebSocketNonTLS2TLS(t *testing.T) {
/*
Arrange
*/
// Create a WebSocket test server (TLS)
srv := createSimpleWebSocketServer(true)
defer srv.Close()
// create proxy server (Non-TLS to TLS)
ts := createSimpleProxyServer(t, srv, false, true)
defer ts.Close()
tsURL, _ := url.Parse(ts.URL)
tsURL.Scheme = "ws"
tsURL.Path = "/"
/*
Act
*/
// Connect to the proxy WebSocket
wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/")
assert.NoError(t, err)
defer wsConn.Close()
// Send message
sendMsg := "Hello, Non TLS to TLS WebSocket!"
err = websocket.Message.Send(wsConn, sendMsg)
assert.NoError(t, err)
/*
Assert
*/
// Read response
var recvMsg string
err = websocket.Message.Receive(wsConn, &recvMsg)
assert.NoError(t, err)
assert.Equal(t, sendMsg, recvMsg)
}
// TestProxyWithConfigWebSocketTLSToNoneTLS tests the proxy with TLS to non-TLS WebSocket connection. (TLS termination)
func TestProxyWithConfigWebSocketTLS2NonTLS(t *testing.T) {
/*
Arrange
*/
// Create a WebSocket test server (non-TLS)
srv := createSimpleWebSocketServer(false)
defer srv.Close()
// create proxy server (TLS to non-TLS)
ts := createSimpleProxyServer(t, srv, true, false)
defer ts.Close()
tsURL, _ := url.Parse(ts.URL)
tsURL.Scheme = "wss"
tsURL.Path = "/"
/*
Act
*/
origin, err := url.Parse(ts.URL)
assert.NoError(t, err)
config := &websocket.Config{
Location: tsURL,
Origin: origin,
TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing
Version: websocket.ProtocolVersionHybi13,
}
wsConn, err := websocket.DialConfig(config)
assert.NoError(t, err)
defer wsConn.Close()
// Send message
sendMsg := "Hello, TLS to NoneTLS WebSocket!"
err = websocket.Message.Send(wsConn, sendMsg)
assert.NoError(t, err)
// Read response
var recvMsg string
err = websocket.Message.Receive(wsConn, &recvMsg)
assert.NoError(t, err)
assert.Equal(t, sendMsg, recvMsg)
}