1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-06-15 00:15:00 +02:00

Configure OAuth2 Proxy to use new upstreams package and LegacyConfig

This commit is contained in:
Joel Speed
2020-05-26 20:06:27 +01:00
parent e932381ba7
commit 5dbcd73722
3 changed files with 70 additions and 329 deletions

View File

@ -8,7 +8,6 @@ import (
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
@ -24,11 +23,11 @@ import (
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
sessionscookie "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/cookie"
"github.com/oauth2-proxy/oauth2-proxy/pkg/upstream"
"github.com/oauth2-proxy/oauth2-proxy/pkg/validation"
"github.com/oauth2-proxy/oauth2-proxy/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/websocket"
)
const (
@ -44,143 +43,6 @@ func init() {
logger.SetFlags(logger.Lshortfile)
}
type WebSocketOrRestHandler struct {
restHandler http.Handler
wsHandler http.Handler
}
func (h *WebSocketOrRestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Upgrade") == "websocket" {
h.wsHandler.ServeHTTP(w, r)
} else {
h.restHandler.ServeHTTP(w, r)
}
}
func TestWebSocketProxy(t *testing.T) {
handler := WebSocketOrRestHandler{
restHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
hostname, _, _ := net.SplitHostPort(r.Host)
_, err := w.Write([]byte(hostname))
if err != nil {
t.Fatal(err)
}
}),
wsHandler: websocket.Handler(func(ws *websocket.Conn) {
defer func(t *testing.T) {
if err := ws.Close(); err != nil {
t.Fatal(err)
}
}(t)
var data []byte
err := websocket.Message.Receive(ws, &data)
if err != nil {
t.Fatal(err)
}
err = websocket.Message.Send(ws, data)
if err != nil {
t.Fatal(err)
}
}),
}
backend := httptest.NewServer(&handler)
t.Cleanup(backend.Close)
backendURL, _ := url.Parse(backend.URL)
opts := baseTestOptions()
var auth hmacauth.HmacAuth
opts.PassHostHeader = true
proxyHandler := NewWebSocketOrRestReverseProxy(backendURL, opts, auth)
frontend := httptest.NewServer(proxyHandler)
t.Cleanup(frontend.Close)
frontendURL, _ := url.Parse(frontend.URL)
frontendWSURL := "ws://" + frontendURL.Host + "/"
ws, err := websocket.Dial(frontendWSURL, "", "http://localhost/")
if err != nil {
t.Fatal(err)
}
request := []byte("hello, world!")
err = websocket.Message.Send(ws, request)
if err != nil {
t.Fatal(err)
}
var response = make([]byte, 1024)
err = websocket.Message.Receive(ws, &response)
if err != nil {
t.Fatal(err)
}
if g, e := string(request), string(response); g != e {
t.Errorf("got body %q; expected %q", g, e)
}
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
res, _ := http.DefaultClient.Do(getReq)
bodyBytes, _ := ioutil.ReadAll(res.Body)
backendHostname, _, _ := net.SplitHostPort(backendURL.Host)
if g, e := string(bodyBytes), backendHostname; g != e {
t.Errorf("got body %q; expected %q", g, e)
}
}
func TestNewReverseProxy(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
hostname, _, _ := net.SplitHostPort(r.Host)
_, err := w.Write([]byte(hostname))
if err != nil {
t.Fatal(err)
}
}))
t.Cleanup(backend.Close)
backendURL, _ := url.Parse(backend.URL)
backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host)
backendHost := net.JoinHostPort(backendHostname, backendPort)
proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/")
proxyHandler := NewReverseProxy(proxyURL, &options.Options{FlushInterval: time.Second})
setProxyUpstreamHostHeader(proxyHandler, proxyURL)
frontend := httptest.NewServer(proxyHandler)
t.Cleanup(frontend.Close)
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
res, _ := http.DefaultClient.Do(getReq)
bodyBytes, _ := ioutil.ReadAll(res.Body)
if g, e := string(bodyBytes), backendHostname; g != e {
t.Errorf("got body %q; expected %q", g, e)
}
}
func TestEncodedSlashes(t *testing.T) {
var seen string
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
seen = r.RequestURI
}))
t.Cleanup(backend.Close)
b, _ := url.Parse(backend.URL)
proxyHandler := NewReverseProxy(b, &options.Options{FlushInterval: time.Second})
setProxyDirector(proxyHandler)
frontend := httptest.NewServer(proxyHandler)
t.Cleanup(frontend.Close)
f, _ := url.Parse(frontend.URL)
encodedPath := "/a%2Fb/?c=1"
getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: f.Host, Opaque: encodedPath}}
_, err := http.DefaultClient.Do(getReq)
if err != nil {
t.Fatal(err)
}
if seen != encodedPath {
t.Errorf("got bad request %q expected %q", seen, encodedPath)
}
}
func TestRobotsTxt(t *testing.T) {
opts := baseTestOptions()
err := validation.Validate(opts)
@ -562,7 +424,14 @@ func TestBasicAuthPassword(t *testing.T) {
}
}))
opts := baseTestOptions()
opts.Upstreams = append(opts.Upstreams, providerServer.URL)
opts.UpstreamServers = options.Upstreams{
{
ID: providerServer.URL,
Path: "/",
URI: providerServer.URL,
},
}
opts.Cookie.Secure = false
opts.PassBasicAuth = true
opts.SetBasicAuth = true
@ -867,7 +736,7 @@ type PassAccessTokenTest struct {
type PassAccessTokenTestOptions struct {
PassAccessToken bool
ProxyUpstream string
ProxyUpstream options.Upstream
}
func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTest, error) {
@ -893,10 +762,17 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTe
}))
patt.opts = baseTestOptions()
patt.opts.Upstreams = append(patt.opts.Upstreams, patt.providerServer.URL)
if opts.ProxyUpstream != "" {
patt.opts.Upstreams = append(patt.opts.Upstreams, opts.ProxyUpstream)
patt.opts.UpstreamServers = options.Upstreams{
{
ID: patt.providerServer.URL,
Path: "/",
URI: patt.providerServer.URL,
},
}
if opts.ProxyUpstream.ID != "" {
patt.opts.UpstreamServers = append(patt.opts.UpstreamServers, opts.ProxyUpstream)
}
patt.opts.Cookie.Secure = false
patt.opts.PassAccessToken = opts.PassAccessToken
err := validation.Validate(patt.opts)
@ -999,7 +875,11 @@ func TestForwardAccessTokenUpstream(t *testing.T) {
func TestStaticProxyUpstream(t *testing.T) {
patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{
PassAccessToken: true,
ProxyUpstream: "static://200/static-proxy",
ProxyUpstream: options.Upstream{
ID: "static-proxy",
Path: "/static-proxy",
Static: true,
},
})
if err != nil {
t.Fatal(err)
@ -1572,7 +1452,13 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) {
t.Cleanup(upstream.Close)
opts := baseTestOptions()
opts.Upstreams = append(opts.Upstreams, upstream.URL)
opts.UpstreamServers = options.Upstreams{
{
ID: upstream.URL,
Path: "/",
URI: upstream.URL,
},
}
opts.SkipAuthPreflight = true
err := validation.Validate(opts)
assert.NoError(t, err)
@ -1641,7 +1527,13 @@ func NewSignatureTest() (*SignatureTest, error) {
if err != nil {
return nil, err
}
opts.Upstreams = append(opts.Upstreams, upstream.URL)
opts.UpstreamServers = options.Upstreams{
{
ID: upstream.URL,
Path: "/",
URI: upstream.URL,
},
}
providerHandler := func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte(`{"access_token": "my_auth_token"}`))
@ -1716,7 +1608,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) er
}
// This is used by the upstream to validate the signature.
st.authenticator.auth = hmacauth.NewHmacAuth(
crypto.SHA1, []byte(key), SignatureHeader, SignatureHeaders)
crypto.SHA1, []byte(key), upstream.SignatureHeader, upstream.SignatureHeaders)
proxy.ServeHTTP(st.rw, req)
return nil
@ -2110,7 +2002,13 @@ func Test_noCacheHeaders(t *testing.T) {
t.Cleanup(upstream.Close)
opts := baseTestOptions()
opts.Upstreams = []string{upstream.URL}
opts.UpstreamServers = options.Upstreams{
{
ID: upstream.URL,
Path: "/",
URI: upstream.URL,
},
}
opts.SkipAuthRegex = []string{".*"}
err := validation.Validate(opts)
assert.NoError(t, err)
@ -2335,7 +2233,13 @@ func TestTrustedIPs(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opts := baseTestOptions()
opts.Upstreams = []string{"static://200"}
opts.UpstreamServers = options.Upstreams{
{
ID: "static",
Path: "/",
Static: true,
},
}
opts.TrustedIPs = tt.trustedIPs
opts.ReverseProxy = tt.reverseProxy
opts.RealClientIPHeader = tt.realClientIPHeader