1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-06-15 00:15:00 +02:00

Merge websocket proxy feature from openshift/oauth-proxy. Original author: Hiram Chirino <hiram@hiramchirino.com>

This commit is contained in:
Adam Szalkowski
2019-03-08 09:15:21 +01:00
parent 21c9d38ada
commit c7193b4085
7 changed files with 145 additions and 26 deletions

View File

@ -19,6 +19,7 @@ import (
"github.com/pusher/oauth2_proxy/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/websocket"
)
func init() {
@ -26,6 +27,83 @@ func init() {
}
type WebSocketOrRestHandler struct {
restHandler http.Handler
wsHandler http.Handler
}
func (h *WebSocketOrRestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Upgrade") == "websocket" {
h.wsHandler.ServeHTTP(w, r)
} else {
h.restHandler.ServeHTTP(w, r)
}
}
func TestWebSocketProxy(t *testing.T) {
handler := WebSocketOrRestHandler{
restHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
hostname, _, _ := net.SplitHostPort(r.Host)
w.Write([]byte(hostname))
}),
wsHandler: websocket.Handler(func(ws *websocket.Conn) {
defer ws.Close()
var data []byte
err := websocket.Message.Receive(ws, &data)
if err != nil {
t.Fatalf("err %s", err)
return
}
err = websocket.Message.Send(ws, data)
if err != nil {
t.Fatalf("err %s", err)
}
return
}),
}
backend := httptest.NewServer(&handler)
defer backend.Close()
backendURL, _ := url.Parse(backend.URL)
options := NewOptions()
var auth hmacauth.HmacAuth
options.PassHostHeader = true
proxyHandler := NewWebSocketOrRestReverseProxy(backendURL, options, auth)
frontend := httptest.NewServer(proxyHandler)
defer frontend.Close()
frontendURL, _ := url.Parse(frontend.URL)
frontendWSURL := "ws://" + frontendURL.Host + "/"
ws, err := websocket.Dial(frontendWSURL, "", "http://localhost/")
if err != nil {
t.Fatalf("err %s", err)
}
request := []byte("hello, world!")
err = websocket.Message.Send(ws, request)
if err != nil {
t.Fatalf("err %s", err)
}
var response = make([]byte, 1024)
websocket.Message.Receive(ws, &response)
if err != nil {
t.Fatalf("err %s", err)
}
if g, e := string(request), string(response); g != e {
t.Errorf("got body %q; expected %q", g, e)
}
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
res, _ := http.DefaultClient.Do(getReq)
bodyBytes, _ := ioutil.ReadAll(res.Body)
backendHostname, _, _ := net.SplitHostPort(backendURL.Host)
if g, e := string(bodyBytes), backendHostname; g != e {
t.Errorf("got body %q; expected %q", g, e)
}
}
func TestNewReverseProxy(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)