diff --git a/oauthproxy.go b/oauthproxy.go index ddeafdcd..c92eb850 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -286,7 +286,7 @@ func buildSignInMessage(opts *options.Options) string { // SkipAuthRegex option (paths only support) or newer SkipAuthRoutes option // (method=path support) func buildRoutesAllowlist(opts *options.Options) ([]*allowedRoute, error) { - var routes []*allowedRoute + routes := make([]*allowedRoute, 0, len(opts.SkipAuthRegex)+len(opts.SkipAuthRoutes)) for _, path := range opts.SkipAuthRegex { compiledRegex, err := regexp.Compile(path) diff --git a/oauthproxy_test.go b/oauthproxy_test.go index d0feec2d..53bc9543 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -1482,28 +1482,28 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) { } func TestAuthSkippedForPreflightRequests(t *testing.T) { - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) _, err := w.Write([]byte("response")) if err != nil { t.Fatal(err) } })) - t.Cleanup(upstream.Close) + t.Cleanup(upstreamServer.Close) opts := baseTestOptions() opts.UpstreamServers = options.Upstreams{ { - ID: upstream.URL, + ID: upstreamServer.URL, Path: "/", - URI: upstream.URL, + URI: upstreamServer.URL, }, } opts.SkipAuthPreflight = true err := validation.Validate(opts) assert.NoError(t, err) - upstreamURL, _ := url.Parse(upstream.URL) + upstreamURL, _ := url.Parse(upstreamServer.URL) opts.SetProvider(NewTestProvider(upstreamURL, "")) proxy, err := NewOAuthProxy(opts, func(string) bool { return false }) @@ -1561,17 +1561,17 @@ func NewSignatureTest() (*SignatureTest, error) { opts.EmailDomains = []string{"acm.org"} authenticator := &SignatureAuthenticator{} - upstream := httptest.NewServer( + upstreamServer := httptest.NewServer( http.HandlerFunc(authenticator.Authenticate)) - upstreamURL, err := url.Parse(upstream.URL) + upstreamURL, err := url.Parse(upstreamServer.URL) if err != nil { return nil, err } opts.UpstreamServers = options.Upstreams{ { - ID: upstream.URL, + ID: upstreamServer.URL, Path: "/", - URI: upstream.URL, + URI: upstreamServer.URL, }, } @@ -1590,7 +1590,7 @@ func NewSignatureTest() (*SignatureTest, error) { return &SignatureTest{ opts, - upstream, + upstreamServer, upstreamURL.Host, provider, make(http.Header), @@ -1974,20 +1974,20 @@ func Test_prepareNoCache(t *testing.T) { } func Test_noCacheHeaders(t *testing.T) { - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := w.Write([]byte("upstream")) if err != nil { t.Error(err) } })) - t.Cleanup(upstream.Close) + t.Cleanup(upstreamServer.Close) opts := baseTestOptions() opts.UpstreamServers = options.Upstreams{ { - ID: upstream.URL, + ID: upstreamServer.URL, Path: "/", - URI: upstream.URL, + URI: upstreamServer.URL, }, } opts.SkipAuthRegex = []string{".*"} @@ -2224,7 +2224,8 @@ func TestTrustedIPs(t *testing.T) { opts.TrustedIPs = tt.trustedIPs opts.ReverseProxy = tt.reverseProxy opts.RealClientIPHeader = tt.realClientIPHeader - validation.Validate(opts) + err := validation.Validate(opts) + assert.NoError(t, err) proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) assert.NoError(t, err) @@ -2240,6 +2241,237 @@ func TestTrustedIPs(t *testing.T) { } } +func Test_buildRoutesAllowlist(t *testing.T) { + testCases := []struct { + name string + skipAuthRegex []string + skipAuthRoutes []string + expectedMethods []string + expectedRegexes []string + shouldError bool + }{ + { + name: "No skip auth configured", + skipAuthRegex: []string{}, + skipAuthRoutes: []string{}, + expectedMethods: []string{}, + expectedRegexes: []string{}, + shouldError: false, + }, + { + name: "Only skipAuthRegex configured", + skipAuthRegex: []string{ + "^/foo/bar", + "^/baz/[0-9]+/thing", + }, + skipAuthRoutes: []string{}, + expectedMethods: []string{ + "", + "", + }, + expectedRegexes: []string{ + "^/foo/bar", + "^/baz/[0-9]+/thing", + }, + shouldError: false, + }, + { + name: "Only skipAuthRoutes configured", + skipAuthRegex: []string{}, + skipAuthRoutes: []string{ + "GET=^/foo/bar", + "POST=^/baz/[0-9]+/thing", + "^/all/methods$", + "WEIRD=^/methods/are/allowed", + "PATCH=/second/equals?are=handled&just=fine", + }, + expectedMethods: []string{ + "GET", + "POST", + "", + "WEIRD", + "PATCH", + }, + expectedRegexes: []string{ + "^/foo/bar", + "^/baz/[0-9]+/thing", + "^/all/methods$", + "^/methods/are/allowed", + "/second/equals?are=handled&just=fine", + }, + shouldError: false, + }, + { + name: "Both skipAuthRegexes and skipAuthRoutes configured", + skipAuthRegex: []string{ + "^/foo/bar/regex", + "^/baz/[0-9]+/thing/regex", + }, + skipAuthRoutes: []string{ + "GET=^/foo/bar", + "POST=^/baz/[0-9]+/thing", + "^/all/methods$", + }, + expectedMethods: []string{ + "", + "", + "GET", + "POST", + "", + }, + expectedRegexes: []string{ + "^/foo/bar/regex", + "^/baz/[0-9]+/thing/regex", + "^/foo/bar", + "^/baz/[0-9]+/thing", + "^/all/methods$", + }, + shouldError: false, + }, + { + name: "Invalid skipAuthRegex entry", + skipAuthRegex: []string{ + "^/foo/bar", + "^/baz/[0-9]+/thing", + "(bad[regex", + }, + skipAuthRoutes: []string{}, + expectedMethods: []string{}, + expectedRegexes: []string{}, + shouldError: true, + }, + { + name: "Invalid skipAuthRoutes entry", + skipAuthRegex: []string{}, + skipAuthRoutes: []string{ + "GET=^/foo/bar", + "POST=^/baz/[0-9]+/thing", + "^/all/methods$", + "PUT=(bad[regex", + }, + expectedMethods: []string{}, + expectedRegexes: []string{}, + shouldError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + opts := &options.Options{ + SkipAuthRegex: tc.skipAuthRegex, + SkipAuthRoutes: tc.skipAuthRoutes, + } + routes, err := buildRoutesAllowlist(opts) + if tc.shouldError { + assert.Error(t, err) + return + } else { + assert.NoError(t, err) + } + for i, route := range routes { + assert.Greater(t, len(tc.expectedMethods), i) + assert.Equal(t, route.method, tc.expectedMethods[i]) + assert.Greater(t, len(tc.expectedRegexes), i) + assert.Equal(t, route.pathRegex.String(), tc.expectedRegexes[i]) + } + }) + } +} + +func TestAllowedRequest(t *testing.T) { + upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + _, err := w.Write([]byte("Allowed Request")) + if err != nil { + t.Fatal(err) + } + })) + t.Cleanup(upstreamServer.Close) + + opts := baseTestOptions() + opts.UpstreamServers = options.Upstreams{ + { + ID: upstreamServer.URL, + Path: "/", + URI: upstreamServer.URL, + }, + } + opts.SkipAuthRegex = []string{ + "^/skip/auth/regex$", + } + opts.SkipAuthRoutes = []string{ + "GET=^/skip/auth/routes/get", + } + err := validation.Validate(opts) + assert.NoError(t, err) + proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) + if err != nil { + t.Fatal(err) + } + + testCases := []struct { + name string + method string + url string + allowed bool + }{ + { + name: "Regex GET allowed", + method: "GET", + url: "/skip/auth/regex", + allowed: true, + }, + { + name: "Regex POST allowed ", + method: "POST", + url: "/skip/auth/regex", + allowed: true, + }, + { + name: "Regex denied", + method: "GET", + url: "/wrong/denied", + allowed: false, + }, + { + name: "Route allowed", + method: "GET", + url: "/skip/auth/routes/get", + allowed: true, + }, + { + name: "Route denied with wrong method", + method: "PATCH", + url: "/skip/auth/routes/get", + allowed: false, + }, + { + name: "Route denied with wrong path", + method: "GET", + url: "/skip/auth/routes/wrong/path", + allowed: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req, err := http.NewRequest(tc.method, tc.url, nil) + assert.NoError(t, err) + assert.Equal(t, tc.allowed, proxy.isAllowedRoute(req)) + + rw := httptest.NewRecorder() + proxy.ServeHTTP(rw, req) + + if tc.allowed { + assert.Equal(t, 200, rw.Code) + assert.Equal(t, "Allowed Request", rw.Body.String()) + } else { + assert.Equal(t, 403, rw.Code) + } + }) + } +} + func TestProxyAllowedGroups(t *testing.T) { tests := []struct { name string @@ -2265,18 +2497,18 @@ func TestProxyAllowedGroups(t *testing.T) { CreatedAt: &created, } - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) })) - t.Cleanup(upstream.Close) + t.Cleanup(upstreamServer.Close) test, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) { opts.AllowedGroups = tt.allowedGroups opts.UpstreamServers = options.Upstreams{ { - ID: upstream.URL, + ID: upstreamServer.URL, Path: "/", - URI: upstream.URL, + URI: upstreamServer.URL, }, } }) @@ -2287,7 +2519,8 @@ func TestProxyAllowedGroups(t *testing.T) { test.req, _ = http.NewRequest("GET", "/", nil) test.req.Header.Add("accept", applicationJSON) - test.SaveSession(session) + err = test.SaveSession(session) + assert.NoError(t, err) test.proxy.ServeHTTP(test.rw, test.req) if tt.expectUnauthorized {