1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-06-17 00:17:40 +02:00
This commit is contained in:
Joel Speed
2022-06-03 12:41:30 +01:00
parent 374a676c9d
commit 0dbda5dfac
9 changed files with 520 additions and 493 deletions

View File

@ -1358,7 +1358,13 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) {
},
},
}
opts.SkipAuthPreflight = true
opts.Authorization.RequestRules = []options.AuthorizationRule{
{
ID: "skip-auth-preflight",
Methods: []string{http.MethodOptions},
Policy: options.AllowPolicy,
},
}
err := validation.Validate(opts)
assert.NoError(t, err)
@ -1889,7 +1895,14 @@ func Test_noCacheHeaders(t *testing.T) {
},
},
}
opts.SkipAuthRegex = []string{".*"}
opts.Authorization.RequestRules = []options.AuthorizationRule{
{
ID: "wildcard",
Path: ".*",
Policy: options.AllowPolicy,
},
}
err := validation.Validate(opts)
assert.NoError(t, err)
proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
@ -2161,7 +2174,16 @@ func TestTrustedIPs(t *testing.T) {
},
},
}
opts.TrustedIPs = tt.trustedIPs
if len(tt.trustedIPs) > 0 {
opts.Authorization.RequestRules = []options.AuthorizationRule{
{
ID: "trusted-ips",
IPs: tt.trustedIPs,
Policy: options.AllowPolicy,
},
}
}
opts.ReverseProxy = tt.reverseProxy
opts.RealClientIPHeader = tt.realClientIPHeader
err := validation.Validate(opts)
@ -2181,160 +2203,160 @@ func TestTrustedIPs(t *testing.T) {
}
}
func Test_buildRoutesAllowlist(t *testing.T) {
type expectedAllowedRoute struct {
method string
regexString string
}
testCases := []struct {
name string
skipAuthRegex []string
skipAuthRoutes []string
expectedRoutes []expectedAllowedRoute
shouldError bool
}{
{
name: "No skip auth configured",
skipAuthRegex: []string{},
skipAuthRoutes: []string{},
expectedRoutes: []expectedAllowedRoute{},
shouldError: false,
},
{
name: "Only skipAuthRegex configured",
skipAuthRegex: []string{
"^/foo/bar",
"^/baz/[0-9]+/thing",
},
skipAuthRoutes: []string{},
expectedRoutes: []expectedAllowedRoute{
{
method: "",
regexString: "^/foo/bar",
},
{
method: "",
regexString: "^/baz/[0-9]+/thing",
},
},
shouldError: false,
},
{
name: "Only skipAuthRoutes configured",
skipAuthRegex: []string{},
skipAuthRoutes: []string{
"GET=^/foo/bar",
"POST=^/baz/[0-9]+/thing",
"^/all/methods$",
"WEIRD=^/methods/are/allowed",
"PATCH=/second/equals?are=handled&just=fine",
},
expectedRoutes: []expectedAllowedRoute{
{
method: "GET",
regexString: "^/foo/bar",
},
{
method: "POST",
regexString: "^/baz/[0-9]+/thing",
},
{
method: "",
regexString: "^/all/methods$",
},
{
method: "WEIRD",
regexString: "^/methods/are/allowed",
},
{
method: "PATCH",
regexString: "/second/equals?are=handled&just=fine",
},
},
shouldError: false,
},
{
name: "Both skipAuthRegexes and skipAuthRoutes configured",
skipAuthRegex: []string{
"^/foo/bar/regex",
"^/baz/[0-9]+/thing/regex",
},
skipAuthRoutes: []string{
"GET=^/foo/bar",
"POST=^/baz/[0-9]+/thing",
"^/all/methods$",
},
expectedRoutes: []expectedAllowedRoute{
{
method: "",
regexString: "^/foo/bar/regex",
},
{
method: "",
regexString: "^/baz/[0-9]+/thing/regex",
},
{
method: "GET",
regexString: "^/foo/bar",
},
{
method: "POST",
regexString: "^/baz/[0-9]+/thing",
},
{
method: "",
regexString: "^/all/methods$",
},
},
shouldError: false,
},
{
name: "Invalid skipAuthRegex entry",
skipAuthRegex: []string{
"^/foo/bar",
"^/baz/[0-9]+/thing",
"(bad[regex",
},
skipAuthRoutes: []string{},
expectedRoutes: []expectedAllowedRoute{},
shouldError: true,
},
{
name: "Invalid skipAuthRoutes entry",
skipAuthRegex: []string{},
skipAuthRoutes: []string{
"GET=^/foo/bar",
"POST=^/baz/[0-9]+/thing",
"^/all/methods$",
"PUT=(bad[regex",
},
expectedRoutes: []expectedAllowedRoute{},
shouldError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
opts := &options.Options{
SkipAuthRegex: tc.skipAuthRegex,
SkipAuthRoutes: tc.skipAuthRoutes,
}
routes, err := buildRoutesAllowlist(opts)
if tc.shouldError {
assert.Error(t, err)
return
}
assert.NoError(t, err)
for i, route := range routes {
assert.Greater(t, len(tc.expectedRoutes), i)
assert.Equal(t, route.method, tc.expectedRoutes[i].method)
assert.Equal(t, route.pathRegex.String(), tc.expectedRoutes[i].regexString)
}
})
}
}
// func Test_buildRoutesAllowlist(t *testing.T) {
// type expectedAllowedRoute struct {
// method string
// regexString string
// }
//
// testCases := []struct {
// name string
// skipAuthRegex []string
// skipAuthRoutes []string
// expectedRoutes []expectedAllowedRoute
// shouldError bool
// }{
// {
// name: "No skip auth configured",
// skipAuthRegex: []string{},
// skipAuthRoutes: []string{},
// expectedRoutes: []expectedAllowedRoute{},
// shouldError: false,
// },
// {
// name: "Only skipAuthRegex configured",
// skipAuthRegex: []string{
// "^/foo/bar",
// "^/baz/[0-9]+/thing",
// },
// skipAuthRoutes: []string{},
// expectedRoutes: []expectedAllowedRoute{
// {
// method: "",
// regexString: "^/foo/bar",
// },
// {
// method: "",
// regexString: "^/baz/[0-9]+/thing",
// },
// },
// shouldError: false,
// },
// {
// name: "Only skipAuthRoutes configured",
// skipAuthRegex: []string{},
// skipAuthRoutes: []string{
// "GET=^/foo/bar",
// "POST=^/baz/[0-9]+/thing",
// "^/all/methods$",
// "WEIRD=^/methods/are/allowed",
// "PATCH=/second/equals?are=handled&just=fine",
// },
// expectedRoutes: []expectedAllowedRoute{
// {
// method: "GET",
// regexString: "^/foo/bar",
// },
// {
// method: "POST",
// regexString: "^/baz/[0-9]+/thing",
// },
// {
// method: "",
// regexString: "^/all/methods$",
// },
// {
// method: "WEIRD",
// regexString: "^/methods/are/allowed",
// },
// {
// method: "PATCH",
// regexString: "/second/equals?are=handled&just=fine",
// },
// },
// shouldError: false,
// },
// {
// name: "Both skipAuthRegexes and skipAuthRoutes configured",
// skipAuthRegex: []string{
// "^/foo/bar/regex",
// "^/baz/[0-9]+/thing/regex",
// },
// skipAuthRoutes: []string{
// "GET=^/foo/bar",
// "POST=^/baz/[0-9]+/thing",
// "^/all/methods$",
// },
// expectedRoutes: []expectedAllowedRoute{
// {
// method: "",
// regexString: "^/foo/bar/regex",
// },
// {
// method: "",
// regexString: "^/baz/[0-9]+/thing/regex",
// },
// {
// method: "GET",
// regexString: "^/foo/bar",
// },
// {
// method: "POST",
// regexString: "^/baz/[0-9]+/thing",
// },
// {
// method: "",
// regexString: "^/all/methods$",
// },
// },
// shouldError: false,
// },
// {
// name: "Invalid skipAuthRegex entry",
// skipAuthRegex: []string{
// "^/foo/bar",
// "^/baz/[0-9]+/thing",
// "(bad[regex",
// },
// skipAuthRoutes: []string{},
// expectedRoutes: []expectedAllowedRoute{},
// shouldError: true,
// },
// {
// name: "Invalid skipAuthRoutes entry",
// skipAuthRegex: []string{},
// skipAuthRoutes: []string{
// "GET=^/foo/bar",
// "POST=^/baz/[0-9]+/thing",
// "^/all/methods$",
// "PUT=(bad[regex",
// },
// expectedRoutes: []expectedAllowedRoute{},
// shouldError: true,
// },
// }
//
// for _, tc := range testCases {
// t.Run(tc.name, func(t *testing.T) {
// opts := &options.Options{
// SkipAuthRegex: tc.skipAuthRegex,
// SkipAuthRoutes: tc.skipAuthRoutes,
// }
// routes, err := buildRoutesAllowlist(opts)
// if tc.shouldError {
// assert.Error(t, err)
// return
// }
// assert.NoError(t, err)
//
// for i, route := range routes {
// assert.Greater(t, len(tc.expectedRoutes), i)
// assert.Equal(t, route.method, tc.expectedRoutes[i].method)
// assert.Equal(t, route.pathRegex.String(), tc.expectedRoutes[i].regexString)
// }
// })
// }
// }
func TestAllowedRequest(t *testing.T) {
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -2356,12 +2378,20 @@ func TestAllowedRequest(t *testing.T) {
},
},
}
opts.SkipAuthRegex = []string{
"^/skip/auth/regex$",
}
opts.SkipAuthRoutes = []string{
"GET=^/skip/auth/routes/get",
opts.Authorization.RequestRules = []options.AuthorizationRule{
{
ID: "regex",
Path: "^/skip/auth/regex$",
Policy: options.AllowPolicy,
},
{
ID: "route",
Path: "^/skip/auth/routes/get",
Methods: []string{http.MethodGet},
Policy: options.AllowPolicy,
},
}
err := validation.Validate(opts)
assert.NoError(t, err)
proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
@ -2417,7 +2447,6 @@ func TestAllowedRequest(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
req, err := http.NewRequest(tc.method, tc.url, nil)
assert.NoError(t, err)
assert.Equal(t, tc.allowed, proxy.isAllowedRoute(req))
rw := httptest.NewRecorder()
proxy.ServeHTTP(rw, req)
@ -2670,8 +2699,18 @@ func TestAuthOnlyAllowedGroupsWithSkipMethods(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
test, err := NewAuthOnlyEndpointTest("?allowed_groups=a,b", func(opts *options.Options) {
opts.SkipAuthPreflight = true
opts.TrustedIPs = []string{"1.2.3.4"}
opts.Authorization.RequestRules = []options.AuthorizationRule{
{
ID: "skip-auth-preflight",
Methods: []string{http.MethodOptions},
Policy: options.AllowPolicy,
},
{
ID: "trusted-ips",
IPs: []string{"1.2.3.4"},
Policy: options.AllowPolicy,
},
}
})
if err != nil {
t.Fatal(err)