From cfd3de807c1ea4bc072203438e8239e65040e862 Mon Sep 17 00:00:00 2001
From: Nick Meves <nick.meves@greenhouse.io>
Date: Wed, 23 Sep 2020 20:16:05 -0700
Subject: [PATCH] Add tests for skip auth functionality

---
 oauthproxy.go      |   2 +-
 oauthproxy_test.go | 273 +++++++++++++++++++++++++++++++++++++++++----
 2 files changed, 254 insertions(+), 21 deletions(-)

diff --git a/oauthproxy.go b/oauthproxy.go
index ddeafdcd..c92eb850 100644
--- a/oauthproxy.go
+++ b/oauthproxy.go
@@ -286,7 +286,7 @@ func buildSignInMessage(opts *options.Options) string {
 // SkipAuthRegex option (paths only support) or newer SkipAuthRoutes option
 // (method=path support)
 func buildRoutesAllowlist(opts *options.Options) ([]*allowedRoute, error) {
-	var routes []*allowedRoute
+	routes := make([]*allowedRoute, 0, len(opts.SkipAuthRegex)+len(opts.SkipAuthRoutes))
 
 	for _, path := range opts.SkipAuthRegex {
 		compiledRegex, err := regexp.Compile(path)
diff --git a/oauthproxy_test.go b/oauthproxy_test.go
index d0feec2d..53bc9543 100644
--- a/oauthproxy_test.go
+++ b/oauthproxy_test.go
@@ -1482,28 +1482,28 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) {
 }
 
 func TestAuthSkippedForPreflightRequests(t *testing.T) {
-	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+	upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		w.WriteHeader(200)
 		_, err := w.Write([]byte("response"))
 		if err != nil {
 			t.Fatal(err)
 		}
 	}))
-	t.Cleanup(upstream.Close)
+	t.Cleanup(upstreamServer.Close)
 
 	opts := baseTestOptions()
 	opts.UpstreamServers = options.Upstreams{
 		{
-			ID:   upstream.URL,
+			ID:   upstreamServer.URL,
 			Path: "/",
-			URI:  upstream.URL,
+			URI:  upstreamServer.URL,
 		},
 	}
 	opts.SkipAuthPreflight = true
 	err := validation.Validate(opts)
 	assert.NoError(t, err)
 
-	upstreamURL, _ := url.Parse(upstream.URL)
+	upstreamURL, _ := url.Parse(upstreamServer.URL)
 	opts.SetProvider(NewTestProvider(upstreamURL, ""))
 
 	proxy, err := NewOAuthProxy(opts, func(string) bool { return false })
@@ -1561,17 +1561,17 @@ func NewSignatureTest() (*SignatureTest, error) {
 	opts.EmailDomains = []string{"acm.org"}
 
 	authenticator := &SignatureAuthenticator{}
-	upstream := httptest.NewServer(
+	upstreamServer := httptest.NewServer(
 		http.HandlerFunc(authenticator.Authenticate))
-	upstreamURL, err := url.Parse(upstream.URL)
+	upstreamURL, err := url.Parse(upstreamServer.URL)
 	if err != nil {
 		return nil, err
 	}
 	opts.UpstreamServers = options.Upstreams{
 		{
-			ID:   upstream.URL,
+			ID:   upstreamServer.URL,
 			Path: "/",
-			URI:  upstream.URL,
+			URI:  upstreamServer.URL,
 		},
 	}
 
@@ -1590,7 +1590,7 @@ func NewSignatureTest() (*SignatureTest, error) {
 
 	return &SignatureTest{
 		opts,
-		upstream,
+		upstreamServer,
 		upstreamURL.Host,
 		provider,
 		make(http.Header),
@@ -1974,20 +1974,20 @@ func Test_prepareNoCache(t *testing.T) {
 }
 
 func Test_noCacheHeaders(t *testing.T) {
-	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+	upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		_, err := w.Write([]byte("upstream"))
 		if err != nil {
 			t.Error(err)
 		}
 	}))
-	t.Cleanup(upstream.Close)
+	t.Cleanup(upstreamServer.Close)
 
 	opts := baseTestOptions()
 	opts.UpstreamServers = options.Upstreams{
 		{
-			ID:   upstream.URL,
+			ID:   upstreamServer.URL,
 			Path: "/",
-			URI:  upstream.URL,
+			URI:  upstreamServer.URL,
 		},
 	}
 	opts.SkipAuthRegex = []string{".*"}
@@ -2224,7 +2224,8 @@ func TestTrustedIPs(t *testing.T) {
 			opts.TrustedIPs = tt.trustedIPs
 			opts.ReverseProxy = tt.reverseProxy
 			opts.RealClientIPHeader = tt.realClientIPHeader
-			validation.Validate(opts)
+			err := validation.Validate(opts)
+			assert.NoError(t, err)
 
 			proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
 			assert.NoError(t, err)
@@ -2240,6 +2241,237 @@ func TestTrustedIPs(t *testing.T) {
 	}
 }
 
+func Test_buildRoutesAllowlist(t *testing.T) {
+	testCases := []struct {
+		name            string
+		skipAuthRegex   []string
+		skipAuthRoutes  []string
+		expectedMethods []string
+		expectedRegexes []string
+		shouldError     bool
+	}{
+		{
+			name:            "No skip auth configured",
+			skipAuthRegex:   []string{},
+			skipAuthRoutes:  []string{},
+			expectedMethods: []string{},
+			expectedRegexes: []string{},
+			shouldError:     false,
+		},
+		{
+			name: "Only skipAuthRegex configured",
+			skipAuthRegex: []string{
+				"^/foo/bar",
+				"^/baz/[0-9]+/thing",
+			},
+			skipAuthRoutes: []string{},
+			expectedMethods: []string{
+				"",
+				"",
+			},
+			expectedRegexes: []string{
+				"^/foo/bar",
+				"^/baz/[0-9]+/thing",
+			},
+			shouldError: false,
+		},
+		{
+			name:          "Only skipAuthRoutes configured",
+			skipAuthRegex: []string{},
+			skipAuthRoutes: []string{
+				"GET=^/foo/bar",
+				"POST=^/baz/[0-9]+/thing",
+				"^/all/methods$",
+				"WEIRD=^/methods/are/allowed",
+				"PATCH=/second/equals?are=handled&just=fine",
+			},
+			expectedMethods: []string{
+				"GET",
+				"POST",
+				"",
+				"WEIRD",
+				"PATCH",
+			},
+			expectedRegexes: []string{
+				"^/foo/bar",
+				"^/baz/[0-9]+/thing",
+				"^/all/methods$",
+				"^/methods/are/allowed",
+				"/second/equals?are=handled&just=fine",
+			},
+			shouldError: false,
+		},
+		{
+			name: "Both skipAuthRegexes and skipAuthRoutes configured",
+			skipAuthRegex: []string{
+				"^/foo/bar/regex",
+				"^/baz/[0-9]+/thing/regex",
+			},
+			skipAuthRoutes: []string{
+				"GET=^/foo/bar",
+				"POST=^/baz/[0-9]+/thing",
+				"^/all/methods$",
+			},
+			expectedMethods: []string{
+				"",
+				"",
+				"GET",
+				"POST",
+				"",
+			},
+			expectedRegexes: []string{
+				"^/foo/bar/regex",
+				"^/baz/[0-9]+/thing/regex",
+				"^/foo/bar",
+				"^/baz/[0-9]+/thing",
+				"^/all/methods$",
+			},
+			shouldError: false,
+		},
+		{
+			name: "Invalid skipAuthRegex entry",
+			skipAuthRegex: []string{
+				"^/foo/bar",
+				"^/baz/[0-9]+/thing",
+				"(bad[regex",
+			},
+			skipAuthRoutes:  []string{},
+			expectedMethods: []string{},
+			expectedRegexes: []string{},
+			shouldError:     true,
+		},
+		{
+			name:          "Invalid skipAuthRoutes entry",
+			skipAuthRegex: []string{},
+			skipAuthRoutes: []string{
+				"GET=^/foo/bar",
+				"POST=^/baz/[0-9]+/thing",
+				"^/all/methods$",
+				"PUT=(bad[regex",
+			},
+			expectedMethods: []string{},
+			expectedRegexes: []string{},
+			shouldError:     true,
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			opts := &options.Options{
+				SkipAuthRegex:  tc.skipAuthRegex,
+				SkipAuthRoutes: tc.skipAuthRoutes,
+			}
+			routes, err := buildRoutesAllowlist(opts)
+			if tc.shouldError {
+				assert.Error(t, err)
+				return
+			} else {
+				assert.NoError(t, err)
+			}
+			for i, route := range routes {
+				assert.Greater(t, len(tc.expectedMethods), i)
+				assert.Equal(t, route.method, tc.expectedMethods[i])
+				assert.Greater(t, len(tc.expectedRegexes), i)
+				assert.Equal(t, route.pathRegex.String(), tc.expectedRegexes[i])
+			}
+		})
+	}
+}
+
+func TestAllowedRequest(t *testing.T) {
+	upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(200)
+		_, err := w.Write([]byte("Allowed Request"))
+		if err != nil {
+			t.Fatal(err)
+		}
+	}))
+	t.Cleanup(upstreamServer.Close)
+
+	opts := baseTestOptions()
+	opts.UpstreamServers = options.Upstreams{
+		{
+			ID:   upstreamServer.URL,
+			Path: "/",
+			URI:  upstreamServer.URL,
+		},
+	}
+	opts.SkipAuthRegex = []string{
+		"^/skip/auth/regex$",
+	}
+	opts.SkipAuthRoutes = []string{
+		"GET=^/skip/auth/routes/get",
+	}
+	err := validation.Validate(opts)
+	assert.NoError(t, err)
+	proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	testCases := []struct {
+		name    string
+		method  string
+		url     string
+		allowed bool
+	}{
+		{
+			name:    "Regex GET allowed",
+			method:  "GET",
+			url:     "/skip/auth/regex",
+			allowed: true,
+		},
+		{
+			name:    "Regex POST allowed ",
+			method:  "POST",
+			url:     "/skip/auth/regex",
+			allowed: true,
+		},
+		{
+			name:    "Regex denied",
+			method:  "GET",
+			url:     "/wrong/denied",
+			allowed: false,
+		},
+		{
+			name:    "Route allowed",
+			method:  "GET",
+			url:     "/skip/auth/routes/get",
+			allowed: true,
+		},
+		{
+			name:    "Route denied with wrong method",
+			method:  "PATCH",
+			url:     "/skip/auth/routes/get",
+			allowed: false,
+		},
+		{
+			name:    "Route denied with wrong path",
+			method:  "GET",
+			url:     "/skip/auth/routes/wrong/path",
+			allowed: false,
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			req, err := http.NewRequest(tc.method, tc.url, nil)
+			assert.NoError(t, err)
+			assert.Equal(t, tc.allowed, proxy.isAllowedRoute(req))
+
+			rw := httptest.NewRecorder()
+			proxy.ServeHTTP(rw, req)
+
+			if tc.allowed {
+				assert.Equal(t, 200, rw.Code)
+				assert.Equal(t, "Allowed Request", rw.Body.String())
+			} else {
+				assert.Equal(t, 403, rw.Code)
+			}
+		})
+	}
+}
+
 func TestProxyAllowedGroups(t *testing.T) {
 	tests := []struct {
 		name               string
@@ -2265,18 +2497,18 @@ func TestProxyAllowedGroups(t *testing.T) {
 				CreatedAt:   &created,
 			}
 
-			upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 				w.WriteHeader(200)
 			}))
-			t.Cleanup(upstream.Close)
+			t.Cleanup(upstreamServer.Close)
 
 			test, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) {
 				opts.AllowedGroups = tt.allowedGroups
 				opts.UpstreamServers = options.Upstreams{
 					{
-						ID:   upstream.URL,
+						ID:   upstreamServer.URL,
 						Path: "/",
-						URI:  upstream.URL,
+						URI:  upstreamServer.URL,
 					},
 				}
 			})
@@ -2287,7 +2519,8 @@ func TestProxyAllowedGroups(t *testing.T) {
 			test.req, _ = http.NewRequest("GET", "/", nil)
 
 			test.req.Header.Add("accept", applicationJSON)
-			test.SaveSession(session)
+			err = test.SaveSession(session)
+			assert.NoError(t, err)
 			test.proxy.ServeHTTP(test.rw, test.req)
 
 			if tt.expectUnauthorized {