diff --git a/oauthproxy.go b/oauthproxy.go index 308f806b..7e52646c 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -11,7 +11,6 @@ import ( "net/url" "os" "os/signal" - "regexp" "strings" "syscall" "time" @@ -61,12 +60,6 @@ var ( ErrAccessDenied = errors.New("access denied") ) -// allowedRoute manages method + path based allowlists -type allowedRoute struct { - method string - pathRegex *regexp.Regexp -} - // OAuthProxy is the main authentication proxy type OAuthProxy struct { CookieOptions *options.Cookie @@ -74,7 +67,6 @@ type OAuthProxy struct { SignInPath string - allowedRoutes []allowedRoute redirectURL *url.URL // the url to receive requests at whitelistDomains []string provider providers.Provider @@ -83,11 +75,9 @@ type OAuthProxy struct { basicAuthValidator basic.Validator basicAuthGroups []string SkipProviderButton bool - skipAuthPreflight bool skipJwtBearerTokens bool forceJSONErrors bool realClientIPParser ipapi.RealClientIPParser - trustedIPs *ip.NetSet sessionChain alice.Chain headersChain alice.Chain @@ -161,21 +151,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr logger.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domains:%s path:%s samesite:%s refresh:%s", opts.Cookie.Name, opts.Cookie.Secure, opts.Cookie.HTTPOnly, opts.Cookie.Expire, strings.Join(opts.Cookie.Domains, ","), opts.Cookie.Path, opts.Cookie.SameSite, refresh) - trustedIPs := ip.NewNetSet() - for _, ipStr := range opts.TrustedIPs { - if ipNet := ip.ParseIPNet(ipStr); ipNet != nil { - trustedIPs.AddIPNet(*ipNet) - } else { - return nil, fmt.Errorf("could not parse IP network (%s)", ipStr) - } - } - - allowedRoutes, err := buildRoutesAllowlist(opts) - if err != nil { - return nil, err - } - - preAuthChain, err := buildPreAuthChain(opts) + preAuthChain, err := buildPreAuthChain(opts, pageWriter) if err != nil { return nil, fmt.Errorf("could not build pre-auth chain: %v", err) } @@ -201,14 +177,11 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr provider: provider, sessionStore: sessionStore, redirectURL: redirectURL, - allowedRoutes: allowedRoutes, whitelistDomains: opts.WhitelistDomains, - skipAuthPreflight: opts.SkipAuthPreflight, skipJwtBearerTokens: opts.SkipJwtBearerTokens, realClientIPParser: opts.GetRealClientIPParser(), SkipProviderButton: opts.SkipProviderButton, forceJSONErrors: opts.ForceJSONErrors, - trustedIPs: trustedIPs, basicAuthValidator: basicAuthValidator, basicAuthGroups: opts.HtpasswdUserGroups, @@ -316,7 +289,7 @@ func (p *OAuthProxy) buildProxySubrouter(s *mux.Router) { // buildPreAuthChain constructs a chain that should process every request before // the OAuth2 Proxy authentication logic kicks in. // For example forcing HTTPS or health checks. -func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { +func buildPreAuthChain(opts *options.Options, pageWriter pagewriter.Writer) (alice.Chain, error) { chain := alice.New(middleware.NewScope(opts.ReverseProxy, opts.Logging.RequestIDHeader)) if opts.ForceHTTPS { @@ -351,6 +324,22 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { chain = chain.Append(middleware.NewRequestMetricsWithDefaultRegistry()) + requestAuthorization, err := middleware.NewRequestAuthorization(pageWriter, opts.Authorization.RequestRules, func(req *http.Request) net.IP { + if opts.GetRealClientIPParser() == nil { + host, _ := util.SplitHostPort(req.RemoteAddr) + return net.ParseIP(host) + } + ip, err := opts.GetRealClientIPParser().GetRealClientIP(req.Header) + if err != nil { + return nil + } + return ip + }) + if err != nil { + return alice.Chain{}, fmt.Errorf("error initialising request authorization middleware: %w", err) + } + chain = chain.Append(requestAuthorization) + return chain, nil } @@ -423,53 +412,6 @@ func buildProviderName(p providers.Provider, override string) string { return p.Data().ProviderName } -// buildRoutesAllowlist builds an []allowedRoute list from either the legacy -// SkipAuthRegex option (paths only support) or newer SkipAuthRoutes option -// (method=path support) -func buildRoutesAllowlist(opts *options.Options) ([]allowedRoute, error) { - routes := make([]allowedRoute, 0, len(opts.SkipAuthRegex)+len(opts.SkipAuthRoutes)) - - for _, path := range opts.SkipAuthRegex { - compiledRegex, err := regexp.Compile(path) - if err != nil { - return nil, err - } - logger.Printf("Skipping auth - Method: ALL | Path: %s", path) - routes = append(routes, allowedRoute{ - method: "", - pathRegex: compiledRegex, - }) - } - - for _, methodPath := range opts.SkipAuthRoutes { - var ( - method string - path string - ) - - parts := strings.SplitN(methodPath, "=", 2) - if len(parts) == 1 { - method = "" - path = parts[0] - } else { - method = strings.ToUpper(parts[0]) - path = parts[1] - } - - compiledRegex, err := regexp.Compile(path) - if err != nil { - return nil, err - } - logger.Printf("Skipping auth - Method: %s | Path: %s", method, path) - routes = append(routes, allowedRoute{ - method: method, - pathRegex: compiledRegex, - }) - } - - return routes, nil -} - // ClearSessionCookie creates a cookie to unset the user's authentication cookie // stored in the user's session func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) error { @@ -512,38 +454,8 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, req *http.Request, code i // IsAllowedRequest is used to check if auth should be skipped for this request func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool { - isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS" - return isPreflightRequestAllowed || p.isAllowedRoute(req) || p.isTrustedIP(req) -} - -// IsAllowedRoute is used to check if the request method & path is allowed without auth -func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool { - for _, route := range p.allowedRoutes { - if (route.method == "" || req.Method == route.method) && route.pathRegex.MatchString(req.URL.Path) { - return true - } - } - return false -} - -// isTrustedIP is used to check if a request comes from a trusted client IP address. -func (p *OAuthProxy) isTrustedIP(req *http.Request) bool { - if p.trustedIPs == nil { - return false - } - - remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req) - if err != nil { - logger.Errorf("Error obtaining real IP for trusted IP list: %v", err) - // Possibly spoofed X-Real-IP header - return false - } - - if remoteAddr == nil { - return false - } - - return p.trustedIPs.Has(remoteAddr) + scope := middlewareapi.GetRequestScope(req) + return scope.Authorization.Policy == middlewareapi.AllowPolicy } // SignInPage writes the sign in template to the response diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 90b27d59..4a04ee96 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -1358,7 +1358,13 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) { }, }, } - opts.SkipAuthPreflight = true + opts.Authorization.RequestRules = []options.AuthorizationRule{ + { + ID: "skip-auth-preflight", + Methods: []string{http.MethodOptions}, + Policy: options.AllowPolicy, + }, + } err := validation.Validate(opts) assert.NoError(t, err) @@ -1889,7 +1895,14 @@ func Test_noCacheHeaders(t *testing.T) { }, }, } - opts.SkipAuthRegex = []string{".*"} + opts.Authorization.RequestRules = []options.AuthorizationRule{ + { + ID: "wildcard", + Path: ".*", + Policy: options.AllowPolicy, + }, + } + err := validation.Validate(opts) assert.NoError(t, err) proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) @@ -2161,7 +2174,16 @@ func TestTrustedIPs(t *testing.T) { }, }, } - opts.TrustedIPs = tt.trustedIPs + if len(tt.trustedIPs) > 0 { + opts.Authorization.RequestRules = []options.AuthorizationRule{ + { + ID: "trusted-ips", + IPs: tt.trustedIPs, + Policy: options.AllowPolicy, + }, + } + } + opts.ReverseProxy = tt.reverseProxy opts.RealClientIPHeader = tt.realClientIPHeader err := validation.Validate(opts) @@ -2181,160 +2203,160 @@ func TestTrustedIPs(t *testing.T) { } } -func Test_buildRoutesAllowlist(t *testing.T) { - type expectedAllowedRoute struct { - method string - regexString string - } - - testCases := []struct { - name string - skipAuthRegex []string - skipAuthRoutes []string - expectedRoutes []expectedAllowedRoute - shouldError bool - }{ - { - name: "No skip auth configured", - skipAuthRegex: []string{}, - skipAuthRoutes: []string{}, - expectedRoutes: []expectedAllowedRoute{}, - shouldError: false, - }, - { - name: "Only skipAuthRegex configured", - skipAuthRegex: []string{ - "^/foo/bar", - "^/baz/[0-9]+/thing", - }, - skipAuthRoutes: []string{}, - expectedRoutes: []expectedAllowedRoute{ - { - method: "", - regexString: "^/foo/bar", - }, - { - method: "", - regexString: "^/baz/[0-9]+/thing", - }, - }, - shouldError: false, - }, - { - name: "Only skipAuthRoutes configured", - skipAuthRegex: []string{}, - skipAuthRoutes: []string{ - "GET=^/foo/bar", - "POST=^/baz/[0-9]+/thing", - "^/all/methods$", - "WEIRD=^/methods/are/allowed", - "PATCH=/second/equals?are=handled&just=fine", - }, - expectedRoutes: []expectedAllowedRoute{ - { - method: "GET", - regexString: "^/foo/bar", - }, - { - method: "POST", - regexString: "^/baz/[0-9]+/thing", - }, - { - method: "", - regexString: "^/all/methods$", - }, - { - method: "WEIRD", - regexString: "^/methods/are/allowed", - }, - { - method: "PATCH", - regexString: "/second/equals?are=handled&just=fine", - }, - }, - shouldError: false, - }, - { - name: "Both skipAuthRegexes and skipAuthRoutes configured", - skipAuthRegex: []string{ - "^/foo/bar/regex", - "^/baz/[0-9]+/thing/regex", - }, - skipAuthRoutes: []string{ - "GET=^/foo/bar", - "POST=^/baz/[0-9]+/thing", - "^/all/methods$", - }, - expectedRoutes: []expectedAllowedRoute{ - { - method: "", - regexString: "^/foo/bar/regex", - }, - { - method: "", - regexString: "^/baz/[0-9]+/thing/regex", - }, - { - method: "GET", - regexString: "^/foo/bar", - }, - { - method: "POST", - regexString: "^/baz/[0-9]+/thing", - }, - { - method: "", - regexString: "^/all/methods$", - }, - }, - shouldError: false, - }, - { - name: "Invalid skipAuthRegex entry", - skipAuthRegex: []string{ - "^/foo/bar", - "^/baz/[0-9]+/thing", - "(bad[regex", - }, - skipAuthRoutes: []string{}, - expectedRoutes: []expectedAllowedRoute{}, - shouldError: true, - }, - { - name: "Invalid skipAuthRoutes entry", - skipAuthRegex: []string{}, - skipAuthRoutes: []string{ - "GET=^/foo/bar", - "POST=^/baz/[0-9]+/thing", - "^/all/methods$", - "PUT=(bad[regex", - }, - expectedRoutes: []expectedAllowedRoute{}, - shouldError: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - opts := &options.Options{ - SkipAuthRegex: tc.skipAuthRegex, - SkipAuthRoutes: tc.skipAuthRoutes, - } - routes, err := buildRoutesAllowlist(opts) - if tc.shouldError { - assert.Error(t, err) - return - } - assert.NoError(t, err) - - for i, route := range routes { - assert.Greater(t, len(tc.expectedRoutes), i) - assert.Equal(t, route.method, tc.expectedRoutes[i].method) - assert.Equal(t, route.pathRegex.String(), tc.expectedRoutes[i].regexString) - } - }) - } -} +// func Test_buildRoutesAllowlist(t *testing.T) { +// type expectedAllowedRoute struct { +// method string +// regexString string +// } +// +// testCases := []struct { +// name string +// skipAuthRegex []string +// skipAuthRoutes []string +// expectedRoutes []expectedAllowedRoute +// shouldError bool +// }{ +// { +// name: "No skip auth configured", +// skipAuthRegex: []string{}, +// skipAuthRoutes: []string{}, +// expectedRoutes: []expectedAllowedRoute{}, +// shouldError: false, +// }, +// { +// name: "Only skipAuthRegex configured", +// skipAuthRegex: []string{ +// "^/foo/bar", +// "^/baz/[0-9]+/thing", +// }, +// skipAuthRoutes: []string{}, +// expectedRoutes: []expectedAllowedRoute{ +// { +// method: "", +// regexString: "^/foo/bar", +// }, +// { +// method: "", +// regexString: "^/baz/[0-9]+/thing", +// }, +// }, +// shouldError: false, +// }, +// { +// name: "Only skipAuthRoutes configured", +// skipAuthRegex: []string{}, +// skipAuthRoutes: []string{ +// "GET=^/foo/bar", +// "POST=^/baz/[0-9]+/thing", +// "^/all/methods$", +// "WEIRD=^/methods/are/allowed", +// "PATCH=/second/equals?are=handled&just=fine", +// }, +// expectedRoutes: []expectedAllowedRoute{ +// { +// method: "GET", +// regexString: "^/foo/bar", +// }, +// { +// method: "POST", +// regexString: "^/baz/[0-9]+/thing", +// }, +// { +// method: "", +// regexString: "^/all/methods$", +// }, +// { +// method: "WEIRD", +// regexString: "^/methods/are/allowed", +// }, +// { +// method: "PATCH", +// regexString: "/second/equals?are=handled&just=fine", +// }, +// }, +// shouldError: false, +// }, +// { +// name: "Both skipAuthRegexes and skipAuthRoutes configured", +// skipAuthRegex: []string{ +// "^/foo/bar/regex", +// "^/baz/[0-9]+/thing/regex", +// }, +// skipAuthRoutes: []string{ +// "GET=^/foo/bar", +// "POST=^/baz/[0-9]+/thing", +// "^/all/methods$", +// }, +// expectedRoutes: []expectedAllowedRoute{ +// { +// method: "", +// regexString: "^/foo/bar/regex", +// }, +// { +// method: "", +// regexString: "^/baz/[0-9]+/thing/regex", +// }, +// { +// method: "GET", +// regexString: "^/foo/bar", +// }, +// { +// method: "POST", +// regexString: "^/baz/[0-9]+/thing", +// }, +// { +// method: "", +// regexString: "^/all/methods$", +// }, +// }, +// shouldError: false, +// }, +// { +// name: "Invalid skipAuthRegex entry", +// skipAuthRegex: []string{ +// "^/foo/bar", +// "^/baz/[0-9]+/thing", +// "(bad[regex", +// }, +// skipAuthRoutes: []string{}, +// expectedRoutes: []expectedAllowedRoute{}, +// shouldError: true, +// }, +// { +// name: "Invalid skipAuthRoutes entry", +// skipAuthRegex: []string{}, +// skipAuthRoutes: []string{ +// "GET=^/foo/bar", +// "POST=^/baz/[0-9]+/thing", +// "^/all/methods$", +// "PUT=(bad[regex", +// }, +// expectedRoutes: []expectedAllowedRoute{}, +// shouldError: true, +// }, +// } +// +// for _, tc := range testCases { +// t.Run(tc.name, func(t *testing.T) { +// opts := &options.Options{ +// SkipAuthRegex: tc.skipAuthRegex, +// SkipAuthRoutes: tc.skipAuthRoutes, +// } +// routes, err := buildRoutesAllowlist(opts) +// if tc.shouldError { +// assert.Error(t, err) +// return +// } +// assert.NoError(t, err) +// +// for i, route := range routes { +// assert.Greater(t, len(tc.expectedRoutes), i) +// assert.Equal(t, route.method, tc.expectedRoutes[i].method) +// assert.Equal(t, route.pathRegex.String(), tc.expectedRoutes[i].regexString) +// } +// }) +// } +// } func TestAllowedRequest(t *testing.T) { upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -2356,12 +2378,20 @@ func TestAllowedRequest(t *testing.T) { }, }, } - opts.SkipAuthRegex = []string{ - "^/skip/auth/regex$", - } - opts.SkipAuthRoutes = []string{ - "GET=^/skip/auth/routes/get", + opts.Authorization.RequestRules = []options.AuthorizationRule{ + { + ID: "regex", + Path: "^/skip/auth/regex$", + Policy: options.AllowPolicy, + }, + { + ID: "route", + Path: "^/skip/auth/routes/get", + Methods: []string{http.MethodGet}, + Policy: options.AllowPolicy, + }, } + err := validation.Validate(opts) assert.NoError(t, err) proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) @@ -2417,7 +2447,6 @@ func TestAllowedRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req, err := http.NewRequest(tc.method, tc.url, nil) assert.NoError(t, err) - assert.Equal(t, tc.allowed, proxy.isAllowedRoute(req)) rw := httptest.NewRecorder() proxy.ServeHTTP(rw, req) @@ -2670,8 +2699,18 @@ func TestAuthOnlyAllowedGroupsWithSkipMethods(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { test, err := NewAuthOnlyEndpointTest("?allowed_groups=a,b", func(opts *options.Options) { - opts.SkipAuthPreflight = true - opts.TrustedIPs = []string{"1.2.3.4"} + opts.Authorization.RequestRules = []options.AuthorizationRule{ + { + ID: "skip-auth-preflight", + Methods: []string{http.MethodOptions}, + Policy: options.AllowPolicy, + }, + { + ID: "trusted-ips", + IPs: []string{"1.2.3.4"}, + Policy: options.AllowPolicy, + }, + } }) if err != nil { t.Fatal(err) diff --git a/pkg/authorization/rules.go b/pkg/authorization/rules.go index 11f97e5f..577e224f 100644 --- a/pkg/authorization/rules.go +++ b/pkg/authorization/rules.go @@ -4,32 +4,24 @@ import ( "net" "net/http" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" ) -type AuthorizationPolicy int - -const ( - NonePolicy AuthorizationPolicy = iota - AllowPolicy - DelegatePolicy - DenyPolicy -) - type RuleSet interface { - MatchesRequest(req *http.Request) AuthorizationPolicy + MatchesRequest(req *http.Request) middlewareapi.AuthorizationPolicy } type rule struct { conditions []condition - policy AuthorizationPolicy + policy middlewareapi.AuthorizationPolicy } -func (r rule) matches(req *http.Request) AuthorizationPolicy { +func (r rule) matches(req *http.Request) middlewareapi.AuthorizationPolicy { for _, condition := range r.conditions { if !condition.matches(req) { // One of the conditions didn't match so this rule does not apply - return NonePolicy + return middlewareapi.OmittedPolicy } } // If all conditions match, return the configured rule policy @@ -60,17 +52,17 @@ func newRule(authRule options.AuthorizationRule, getClientIPFunc func(*http.Requ conditions = append(conditions, condition) } - var policy AuthorizationPolicy + var policy middlewareapi.AuthorizationPolicy switch authRule.Policy { case options.AllowPolicy: - policy = AllowPolicy + policy = middlewareapi.AllowPolicy case options.DelegatePolicy: - policy = DelegatePolicy + policy = middlewareapi.DelegatePolicy case options.DenyPolicy: - policy = DenyPolicy + policy = middlewareapi.DenyPolicy default: // This shouldn't be the case and should be prevented by validation - policy = NonePolicy + policy = middlewareapi.OmittedPolicy } return rule{ @@ -83,15 +75,15 @@ type ruleSet struct { rules []rule } -func (r ruleSet) MatchesRequest(req *http.Request) AuthorizationPolicy { +func (r ruleSet) MatchesRequest(req *http.Request) middlewareapi.AuthorizationPolicy { for _, rule := range r.rules { - if policy := rule.matches(req); policy != NonePolicy { + if policy := rule.matches(req); policy != middlewareapi.OmittedPolicy { // The rule applies to this request, return its policy return policy } } // No rules matched - return NonePolicy + return middlewareapi.OmittedPolicy } func NewRuleSet(requestRules []options.AuthorizationRule, getClientIPFunc func(*http.Request) net.IP) (RuleSet, error) { diff --git a/pkg/authorization/rules_test.go b/pkg/authorization/rules_test.go index 38c007ec..ff8c753f 100644 --- a/pkg/authorization/rules_test.go +++ b/pkg/authorization/rules_test.go @@ -6,10 +6,11 @@ import ( "net/http/httptest" "testing" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" ) -var result AuthorizationPolicy +var result middlewareapi.AuthorizationPolicy func benchmarkRuleSetMatches(ruleCount int, b *testing.B) { rule1 := options.AuthorizationRule{ @@ -53,10 +54,10 @@ func benchmarkRuleSetMatches(ruleCount int, b *testing.B) { req := httptest.NewRequest("GET", "/foo/bar/baz", nil) - var r AuthorizationPolicy + var r middlewareapi.AuthorizationPolicy for n := 0; n < b.N; n++ { r = ruleSet.MatchesRequest(req) - if r != NonePolicy { + if r != middlewareapi.OmittedPolicy { b.Fatal("expected policy not to match") } } diff --git a/pkg/middleware/request_authorization.go b/pkg/middleware/request_authorization.go new file mode 100644 index 00000000..499a73dd --- /dev/null +++ b/pkg/middleware/request_authorization.go @@ -0,0 +1,61 @@ +package middleware + +import ( + "fmt" + "net" + "net/http" + + "github.com/justinas/alice" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app/pagewriter" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authorization" +) + +func NewRequestAuthorization(writer pagewriter.Writer, requestRules []options.AuthorizationRule, getClientIPFunc func(*http.Request) net.IP) (alice.Constructor, error) { + ruleset, err := authorization.NewRuleSet(requestRules, getClientIPFunc) + if err != nil { + return nil, fmt.Errorf("could not initialise ruleset: %w", err) + } + + ra := &requestAuthorizer{ + ruleset: ruleset, + writer: writer, + } + + return ra.checkRequestAuthorization, nil +} + +type requestAuthorizer struct { + ruleset authorization.RuleSet + writer pagewriter.Writer +} + +func (r *requestAuthorizer) checkRequestAuthorization(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + scope := middlewareapi.GetRequestScope(req) + // If scope is nil, this will panic. + // A scope should always be injected before this handler is called. + if scope.Authorization.Policy != middlewareapi.OmittedPolicy { + // The request was already authorized, pass to the next handler + next.ServeHTTP(rw, req) + return + } + + policy := r.ruleset.MatchesRequest(req) + switch policy { + case middlewareapi.AllowPolicy, middlewareapi.DelegatePolicy: + scope.Authorization.Type = middlewareapi.RequestAuthorization + scope.Authorization.Policy = policy + case middlewareapi.DenyPolicy: + r.writer.WriteErrorPage(rw, pagewriter.ErrorPageOpts{ + Status: http.StatusForbidden, + RequestID: scope.RequestID, + AppError: "Request denied by authorization policy", + Messages: []interface{}{"Request denied by authorization policy"}, + }) + } + + next.ServeHTTP(rw, req) + }) +} diff --git a/pkg/validation/allowlist.go b/pkg/validation/allowlist.go deleted file mode 100644 index 56a3fd4c..00000000 --- a/pkg/validation/allowlist.go +++ /dev/null @@ -1,70 +0,0 @@ -package validation - -import ( - "fmt" - "os" - "regexp" - "strings" - - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip" -) - -func validateAllowlists(o *options.Options) []string { - msgs := []string{} - - msgs = append(msgs, validateRoutes(o)...) - msgs = append(msgs, validateRegexes(o)...) - msgs = append(msgs, validateTrustedIPs(o)...) - - if len(o.TrustedIPs) > 0 && o.ReverseProxy { - _, err := fmt.Fprintln(os.Stderr, "WARNING: mixing --trusted-ip with --reverse-proxy is a potential security vulnerability. An attacker can inject a trusted IP into an X-Real-IP or X-Forwarded-For header if they aren't properly protected outside of oauth2-proxy") - if err != nil { - panic(err) - } - } - - return msgs -} - -// validateRoutes validates method=path routes passed with options.SkipAuthRoutes -func validateRoutes(o *options.Options) []string { - msgs := []string{} - for _, route := range o.SkipAuthRoutes { - var regex string - parts := strings.SplitN(route, "=", 2) - if len(parts) == 1 { - regex = parts[0] - } else { - regex = parts[1] - } - _, err := regexp.Compile(regex) - if err != nil { - msgs = append(msgs, fmt.Sprintf("error compiling regex /%s/: %v", regex, err)) - } - } - return msgs -} - -// validateRegex validates regex paths passed with options.SkipAuthRegex -func validateRegexes(o *options.Options) []string { - msgs := []string{} - for _, regex := range o.SkipAuthRegex { - _, err := regexp.Compile(regex) - if err != nil { - msgs = append(msgs, fmt.Sprintf("error compiling regex /%s/: %v", regex, err)) - } - } - return msgs -} - -// validateTrustedIPs validates IP/CIDRs for IP based allowlists -func validateTrustedIPs(o *options.Options) []string { - msgs := []string{} - for i, ipStr := range o.TrustedIPs { - if nil == ip.ParseIPNet(ipStr) { - msgs = append(msgs, fmt.Sprintf("trusted_ips[%d] (%s) could not be recognized", i, ipStr)) - } - } - return msgs -} diff --git a/pkg/validation/allowlist_test.go b/pkg/validation/allowlist_test.go index 4600a718..0a2be154 100644 --- a/pkg/validation/allowlist_test.go +++ b/pkg/validation/allowlist_test.go @@ -1,125 +1,125 @@ package validation -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/ginkgo/extensions/table" - . "github.com/onsi/gomega" - - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" -) - -var _ = Describe("Allowlist", func() { - type validateRoutesTableInput struct { - routes []string - errStrings []string - } - - type validateRegexesTableInput struct { - regexes []string - errStrings []string - } - - type validateTrustedIPsTableInput struct { - trustedIPs []string - errStrings []string - } - - DescribeTable("validateRoutes", - func(r *validateRoutesTableInput) { - opts := &options.Options{ - SkipAuthRoutes: r.routes, - } - Expect(validateRoutes(opts)).To(ConsistOf(r.errStrings)) - }, - Entry("Valid regex routes", &validateRoutesTableInput{ - routes: []string{ - "/foo", - "POST=/foo/bar", - "PUT=^/foo/bar$", - "DELETE=/crazy/(?:regex)?/[^/]+/stuff$", - }, - errStrings: []string{}, - }), - Entry("Bad regexes do not compile", &validateRoutesTableInput{ - routes: []string{ - "POST=/(foo", - "OPTIONS=/foo/bar)", - "GET=^]/foo/bar[$", - "GET=^]/foo/bar[$", - }, - errStrings: []string{ - "error compiling regex //(foo/: error parsing regexp: missing closing ): `/(foo`", - "error compiling regex //foo/bar)/: error parsing regexp: unexpected ): `/foo/bar)`", - "error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`", - "error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`", - }, - }), - ) - - DescribeTable("validateRegexes", - func(r *validateRegexesTableInput) { - opts := &options.Options{ - SkipAuthRegex: r.regexes, - } - Expect(validateRegexes(opts)).To(ConsistOf(r.errStrings)) - }, - Entry("Valid regex routes", &validateRegexesTableInput{ - regexes: []string{ - "/foo", - "/foo/bar", - "^/foo/bar$", - "/crazy/(?:regex)?/[^/]+/stuff$", - }, - errStrings: []string{}, - }), - Entry("Bad regexes do not compile", &validateRegexesTableInput{ - regexes: []string{ - "/(foo", - "/foo/bar)", - "^]/foo/bar[$", - "^]/foo/bar[$", - }, - errStrings: []string{ - "error compiling regex //(foo/: error parsing regexp: missing closing ): `/(foo`", - "error compiling regex //foo/bar)/: error parsing regexp: unexpected ): `/foo/bar)`", - "error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`", - "error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`", - }, - }), - ) - - DescribeTable("validateTrustedIPs", - func(t *validateTrustedIPsTableInput) { - opts := &options.Options{ - TrustedIPs: t.trustedIPs, - } - Expect(validateTrustedIPs(opts)).To(ConsistOf(t.errStrings)) - }, - Entry("Non-overlapping valid IPs", &validateTrustedIPsTableInput{ - trustedIPs: []string{ - "127.0.0.1", - "10.32.0.1/32", - "43.36.201.0/24", - "::1", - "2a12:105:ee7:9234:0:0:0:0/64", - }, - errStrings: []string{}, - }), - Entry("Overlapping valid IPs", &validateTrustedIPsTableInput{ - trustedIPs: []string{ - "135.180.78.199", - "135.180.78.199/32", - "d910:a5a1:16f8:ddf5:e5b9:5cef:a65e:41f4", - "d910:a5a1:16f8:ddf5:e5b9:5cef:a65e:41f4/128", - }, - errStrings: []string{}, - }), - Entry("Invalid IPs", &validateTrustedIPsTableInput{ - trustedIPs: []string{"[::1]", "alkwlkbn/32"}, - errStrings: []string{ - "trusted_ips[0] ([::1]) could not be recognized", - "trusted_ips[1] (alkwlkbn/32) could not be recognized", - }, - }), - ) -}) +// import ( +// . "github.com/onsi/ginkgo" +// . "github.com/onsi/ginkgo/extensions/table" +// . "github.com/onsi/gomega" +// +// "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" +// ) +// +// var _ = Describe("Allowlist", func() { +// type validateRoutesTableInput struct { +// routes []string +// errStrings []string +// } +// +// type validateRegexesTableInput struct { +// regexes []string +// errStrings []string +// } +// +// type validateTrustedIPsTableInput struct { +// trustedIPs []string +// errStrings []string +// } +// +// DescribeTable("validateRoutes", +// func(r *validateRoutesTableInput) { +// opts := &options.Options{ +// SkipAuthRoutes: r.routes, +// } +// Expect(validateRoutes(opts)).To(ConsistOf(r.errStrings)) +// }, +// Entry("Valid regex routes", &validateRoutesTableInput{ +// routes: []string{ +// "/foo", +// "POST=/foo/bar", +// "PUT=^/foo/bar$", +// "DELETE=/crazy/(?:regex)?/[^/]+/stuff$", +// }, +// errStrings: []string{}, +// }), +// Entry("Bad regexes do not compile", &validateRoutesTableInput{ +// routes: []string{ +// "POST=/(foo", +// "OPTIONS=/foo/bar)", +// "GET=^]/foo/bar[$", +// "GET=^]/foo/bar[$", +// }, +// errStrings: []string{ +// "error compiling regex //(foo/: error parsing regexp: missing closing ): `/(foo`", +// "error compiling regex //foo/bar)/: error parsing regexp: unexpected ): `/foo/bar)`", +// "error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`", +// "error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`", +// }, +// }), +// ) +// +// DescribeTable("validateRegexes", +// func(r *validateRegexesTableInput) { +// opts := &options.Options{ +// SkipAuthRegex: r.regexes, +// } +// Expect(validateRegexes(opts)).To(ConsistOf(r.errStrings)) +// }, +// Entry("Valid regex routes", &validateRegexesTableInput{ +// regexes: []string{ +// "/foo", +// "/foo/bar", +// "^/foo/bar$", +// "/crazy/(?:regex)?/[^/]+/stuff$", +// }, +// errStrings: []string{}, +// }), +// Entry("Bad regexes do not compile", &validateRegexesTableInput{ +// regexes: []string{ +// "/(foo", +// "/foo/bar)", +// "^]/foo/bar[$", +// "^]/foo/bar[$", +// }, +// errStrings: []string{ +// "error compiling regex //(foo/: error parsing regexp: missing closing ): `/(foo`", +// "error compiling regex //foo/bar)/: error parsing regexp: unexpected ): `/foo/bar)`", +// "error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`", +// "error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`", +// }, +// }), +// ) +// +// DescribeTable("validateTrustedIPs", +// func(t *validateTrustedIPsTableInput) { +// opts := &options.Options{ +// TrustedIPs: t.trustedIPs, +// } +// Expect(validateTrustedIPs(opts)).To(ConsistOf(t.errStrings)) +// }, +// Entry("Non-overlapping valid IPs", &validateTrustedIPsTableInput{ +// trustedIPs: []string{ +// "127.0.0.1", +// "10.32.0.1/32", +// "43.36.201.0/24", +// "::1", +// "2a12:105:ee7:9234:0:0:0:0/64", +// }, +// errStrings: []string{}, +// }), +// Entry("Overlapping valid IPs", &validateTrustedIPsTableInput{ +// trustedIPs: []string{ +// "135.180.78.199", +// "135.180.78.199/32", +// "d910:a5a1:16f8:ddf5:e5b9:5cef:a65e:41f4", +// "d910:a5a1:16f8:ddf5:e5b9:5cef:a65e:41f4/128", +// }, +// errStrings: []string{}, +// }), +// Entry("Invalid IPs", &validateTrustedIPsTableInput{ +// trustedIPs: []string{"[::1]", "alkwlkbn/32"}, +// errStrings: []string{ +// "trusted_ips[0] ([::1]) could not be recognized", +// "trusted_ips[1] (alkwlkbn/32) could not be recognized", +// }, +// }), +// ) +// }) diff --git a/pkg/validation/authorization.go b/pkg/validation/authorization.go new file mode 100644 index 00000000..b13b0eb1 --- /dev/null +++ b/pkg/validation/authorization.go @@ -0,0 +1,94 @@ +package validation + +import ( + "fmt" + "os" + "regexp" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip" +) + +func validateAuthorization(authorization options.Authorization, reverseProxy bool) []string { + msgs := []string{} + + msgs = append(msgs, validateRequestRules(authorization.RequestRules, reverseProxy)...) + + return msgs +} + +func validateRequestRules(rules []options.AuthorizationRule, reverseProxy bool) []string { + msgs := []string{} + + ids := make(map[string]struct{}) + + for _, rule := range rules { + msgs = append(msgs, validateRequestRule(ids, rule, reverseProxy)...) + } + + return msgs +} + +func validateRequestRule(ids map[string]struct{}, rule options.AuthorizationRule, reverseProxy bool) []string { + msgs := []string{} + + if rule.ID == "" { + msgs = append(msgs, "request rule has empty ID: IDs are required for all request rules") + } + + if _, ok := ids[rule.ID]; ok { + msgs = append(msgs, fmt.Sprintf("multiple request rules found with ID %q: request rule IDs must be unique", rule.ID)) + } + ids[rule.ID] = struct{}{} + + msgs = append(msgs, validateRequestRulePolicy(rule.ID, rule.Policy)...) + msgs = append(msgs, validateRequestRulePath(rule.ID, rule.Path)...) + msgs = append(msgs, validateRequestRuleIPs(rule.ID, rule.IPs, reverseProxy)...) + + return msgs +} + +func validateRequestRulePolicy(ruleID string, policy options.AuthorizationPolicy) []string { + msgs := []string{} + + switch policy { + case options.AllowPolicy, options.DenyPolicy, options.DelegatePolicy: + // Do nothing for valid options + default: + msgs = append(msgs, fmt.Sprintf("request rule %q has invalid policy (%s): policy must be one of %s, %s or %s", ruleID, policy, options.AllowPolicy, options.DenyPolicy, options.DelegatePolicy)) + } + + return msgs +} + +// validateRequestRulePath validates paths for path/regex based conditions +func validateRequestRulePath(ruleID string, path string) []string { + msgs := []string{} + + _, err := regexp.Compile(path) + if err != nil { + msgs = append(msgs, fmt.Sprintf("error compiling path regex (%s) for rule %q: %v", path, ruleID, err)) + } + + return msgs +} + +// validateRequestRuleIPs validates IP/CIDRs for IP based conditions. +func validateRequestRuleIPs(ruleID string, ips []string, reverseProxy bool) []string { + msgs := []string{} + + if len(ips) > 0 && reverseProxy { + _, err := fmt.Fprintln(os.Stderr, "WARNING: mixing IP authorization with --reverse-proxy is a potential security vulnerability. An attacker can inject a trusted IP into an X-Real-IP or X-Forwarded-For header if they aren't properly protected outside of oauth2-proxy") + if err != nil { + panic(err) + } + } + + for i, ipStr := range ips { + if nil == ip.ParseIPNet(ipStr) { + msgs = append(msgs, fmt.Sprintf("rule %q IP [%d] (%s) could not be recognized", ruleID, i, ipStr)) + } + } + + return msgs +} diff --git a/pkg/validation/options.go b/pkg/validation/options.go index cd8f24f9..9857bbce 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -20,6 +20,7 @@ import ( // are of the correct format func Validate(o *options.Options) error { msgs := validateCookie(o.Cookie) + msgs = append(msgs, validateAuthorization(o.Authorization, o.ReverseProxy)...) msgs = append(msgs, validateSessionCookieMinimal(o)...) msgs = append(msgs, validateRedisSessionStore(o)...) msgs = append(msgs, prefixValues("injectRequestHeaders: ", validateHeaders(o.InjectRequestHeaders)...)...) @@ -96,9 +97,6 @@ func Validate(o *options.Options) error { }) } - // Do this after ReverseProxy validation for TrustedIP coordinated checks - msgs = append(msgs, validateAllowlists(o)...) - if len(msgs) != 0 { return fmt.Errorf("invalid configuration:\n %s", strings.Join(msgs, "\n "))