You've already forked oauth2-proxy
mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-06-15 00:15:00 +02:00
Configure OAuth2 Proxy to use new upstreams package and LegacyConfig
This commit is contained in:
@ -8,7 +8,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
@ -24,11 +23,11 @@ import (
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||
sessionscookie "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/cookie"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/upstream"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/validation"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/providers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -44,143 +43,6 @@ func init() {
|
||||
logger.SetFlags(logger.Lshortfile)
|
||||
}
|
||||
|
||||
type WebSocketOrRestHandler struct {
|
||||
restHandler http.Handler
|
||||
wsHandler http.Handler
|
||||
}
|
||||
|
||||
func (h *WebSocketOrRestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Upgrade") == "websocket" {
|
||||
h.wsHandler.ServeHTTP(w, r)
|
||||
} else {
|
||||
h.restHandler.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketProxy(t *testing.T) {
|
||||
handler := WebSocketOrRestHandler{
|
||||
restHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
hostname, _, _ := net.SplitHostPort(r.Host)
|
||||
_, err := w.Write([]byte(hostname))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}),
|
||||
wsHandler: websocket.Handler(func(ws *websocket.Conn) {
|
||||
defer func(t *testing.T) {
|
||||
if err := ws.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}(t)
|
||||
var data []byte
|
||||
err := websocket.Message.Receive(ws, &data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = websocket.Message.Send(ws, data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}),
|
||||
}
|
||||
backend := httptest.NewServer(&handler)
|
||||
t.Cleanup(backend.Close)
|
||||
|
||||
backendURL, _ := url.Parse(backend.URL)
|
||||
|
||||
opts := baseTestOptions()
|
||||
var auth hmacauth.HmacAuth
|
||||
opts.PassHostHeader = true
|
||||
proxyHandler := NewWebSocketOrRestReverseProxy(backendURL, opts, auth)
|
||||
frontend := httptest.NewServer(proxyHandler)
|
||||
t.Cleanup(frontend.Close)
|
||||
|
||||
frontendURL, _ := url.Parse(frontend.URL)
|
||||
frontendWSURL := "ws://" + frontendURL.Host + "/"
|
||||
|
||||
ws, err := websocket.Dial(frontendWSURL, "", "http://localhost/")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
request := []byte("hello, world!")
|
||||
err = websocket.Message.Send(ws, request)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var response = make([]byte, 1024)
|
||||
err = websocket.Message.Receive(ws, &response)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if g, e := string(request), string(response); g != e {
|
||||
t.Errorf("got body %q; expected %q", g, e)
|
||||
}
|
||||
|
||||
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
|
||||
res, _ := http.DefaultClient.Do(getReq)
|
||||
bodyBytes, _ := ioutil.ReadAll(res.Body)
|
||||
backendHostname, _, _ := net.SplitHostPort(backendURL.Host)
|
||||
if g, e := string(bodyBytes), backendHostname; g != e {
|
||||
t.Errorf("got body %q; expected %q", g, e)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewReverseProxy(t *testing.T) {
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
hostname, _, _ := net.SplitHostPort(r.Host)
|
||||
_, err := w.Write([]byte(hostname))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}))
|
||||
t.Cleanup(backend.Close)
|
||||
|
||||
backendURL, _ := url.Parse(backend.URL)
|
||||
backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host)
|
||||
backendHost := net.JoinHostPort(backendHostname, backendPort)
|
||||
proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/")
|
||||
|
||||
proxyHandler := NewReverseProxy(proxyURL, &options.Options{FlushInterval: time.Second})
|
||||
setProxyUpstreamHostHeader(proxyHandler, proxyURL)
|
||||
frontend := httptest.NewServer(proxyHandler)
|
||||
t.Cleanup(frontend.Close)
|
||||
|
||||
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
|
||||
res, _ := http.DefaultClient.Do(getReq)
|
||||
bodyBytes, _ := ioutil.ReadAll(res.Body)
|
||||
if g, e := string(bodyBytes), backendHostname; g != e {
|
||||
t.Errorf("got body %q; expected %q", g, e)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodedSlashes(t *testing.T) {
|
||||
var seen string
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
seen = r.RequestURI
|
||||
}))
|
||||
t.Cleanup(backend.Close)
|
||||
|
||||
b, _ := url.Parse(backend.URL)
|
||||
proxyHandler := NewReverseProxy(b, &options.Options{FlushInterval: time.Second})
|
||||
setProxyDirector(proxyHandler)
|
||||
frontend := httptest.NewServer(proxyHandler)
|
||||
t.Cleanup(frontend.Close)
|
||||
|
||||
f, _ := url.Parse(frontend.URL)
|
||||
encodedPath := "/a%2Fb/?c=1"
|
||||
getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: f.Host, Opaque: encodedPath}}
|
||||
_, err := http.DefaultClient.Do(getReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if seen != encodedPath {
|
||||
t.Errorf("got bad request %q expected %q", seen, encodedPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRobotsTxt(t *testing.T) {
|
||||
opts := baseTestOptions()
|
||||
err := validation.Validate(opts)
|
||||
@ -562,7 +424,14 @@ func TestBasicAuthPassword(t *testing.T) {
|
||||
}
|
||||
}))
|
||||
opts := baseTestOptions()
|
||||
opts.Upstreams = append(opts.Upstreams, providerServer.URL)
|
||||
opts.UpstreamServers = options.Upstreams{
|
||||
{
|
||||
ID: providerServer.URL,
|
||||
Path: "/",
|
||||
URI: providerServer.URL,
|
||||
},
|
||||
}
|
||||
|
||||
opts.Cookie.Secure = false
|
||||
opts.PassBasicAuth = true
|
||||
opts.SetBasicAuth = true
|
||||
@ -867,7 +736,7 @@ type PassAccessTokenTest struct {
|
||||
|
||||
type PassAccessTokenTestOptions struct {
|
||||
PassAccessToken bool
|
||||
ProxyUpstream string
|
||||
ProxyUpstream options.Upstream
|
||||
}
|
||||
|
||||
func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTest, error) {
|
||||
@ -893,10 +762,17 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTe
|
||||
}))
|
||||
|
||||
patt.opts = baseTestOptions()
|
||||
patt.opts.Upstreams = append(patt.opts.Upstreams, patt.providerServer.URL)
|
||||
if opts.ProxyUpstream != "" {
|
||||
patt.opts.Upstreams = append(patt.opts.Upstreams, opts.ProxyUpstream)
|
||||
patt.opts.UpstreamServers = options.Upstreams{
|
||||
{
|
||||
ID: patt.providerServer.URL,
|
||||
Path: "/",
|
||||
URI: patt.providerServer.URL,
|
||||
},
|
||||
}
|
||||
if opts.ProxyUpstream.ID != "" {
|
||||
patt.opts.UpstreamServers = append(patt.opts.UpstreamServers, opts.ProxyUpstream)
|
||||
}
|
||||
|
||||
patt.opts.Cookie.Secure = false
|
||||
patt.opts.PassAccessToken = opts.PassAccessToken
|
||||
err := validation.Validate(patt.opts)
|
||||
@ -999,7 +875,11 @@ func TestForwardAccessTokenUpstream(t *testing.T) {
|
||||
func TestStaticProxyUpstream(t *testing.T) {
|
||||
patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
||||
PassAccessToken: true,
|
||||
ProxyUpstream: "static://200/static-proxy",
|
||||
ProxyUpstream: options.Upstream{
|
||||
ID: "static-proxy",
|
||||
Path: "/static-proxy",
|
||||
Static: true,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -1572,7 +1452,13 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) {
|
||||
t.Cleanup(upstream.Close)
|
||||
|
||||
opts := baseTestOptions()
|
||||
opts.Upstreams = append(opts.Upstreams, upstream.URL)
|
||||
opts.UpstreamServers = options.Upstreams{
|
||||
{
|
||||
ID: upstream.URL,
|
||||
Path: "/",
|
||||
URI: upstream.URL,
|
||||
},
|
||||
}
|
||||
opts.SkipAuthPreflight = true
|
||||
err := validation.Validate(opts)
|
||||
assert.NoError(t, err)
|
||||
@ -1641,7 +1527,13 @@ func NewSignatureTest() (*SignatureTest, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
opts.Upstreams = append(opts.Upstreams, upstream.URL)
|
||||
opts.UpstreamServers = options.Upstreams{
|
||||
{
|
||||
ID: upstream.URL,
|
||||
Path: "/",
|
||||
URI: upstream.URL,
|
||||
},
|
||||
}
|
||||
|
||||
providerHandler := func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := w.Write([]byte(`{"access_token": "my_auth_token"}`))
|
||||
@ -1716,7 +1608,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) er
|
||||
}
|
||||
// This is used by the upstream to validate the signature.
|
||||
st.authenticator.auth = hmacauth.NewHmacAuth(
|
||||
crypto.SHA1, []byte(key), SignatureHeader, SignatureHeaders)
|
||||
crypto.SHA1, []byte(key), upstream.SignatureHeader, upstream.SignatureHeaders)
|
||||
proxy.ServeHTTP(st.rw, req)
|
||||
|
||||
return nil
|
||||
@ -2110,7 +2002,13 @@ func Test_noCacheHeaders(t *testing.T) {
|
||||
t.Cleanup(upstream.Close)
|
||||
|
||||
opts := baseTestOptions()
|
||||
opts.Upstreams = []string{upstream.URL}
|
||||
opts.UpstreamServers = options.Upstreams{
|
||||
{
|
||||
ID: upstream.URL,
|
||||
Path: "/",
|
||||
URI: upstream.URL,
|
||||
},
|
||||
}
|
||||
opts.SkipAuthRegex = []string{".*"}
|
||||
err := validation.Validate(opts)
|
||||
assert.NoError(t, err)
|
||||
@ -2335,7 +2233,13 @@ func TestTrustedIPs(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
opts := baseTestOptions()
|
||||
opts.Upstreams = []string{"static://200"}
|
||||
opts.UpstreamServers = options.Upstreams{
|
||||
{
|
||||
ID: "static",
|
||||
Path: "/",
|
||||
Static: true,
|
||||
},
|
||||
}
|
||||
opts.TrustedIPs = tt.trustedIPs
|
||||
opts.ReverseProxy = tt.reverseProxy
|
||||
opts.RealClientIPHeader = tt.realClientIPHeader
|
||||
|
Reference in New Issue
Block a user