diff --git a/CHANGELOG.md b/CHANGELOG.md index 8f3512a6..4cc0ac0e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,9 @@ to remain consistent with CLI flags. You should specify `code_challenge_method` - [#1788](https://github.com/oauth2-proxy/oauth2-proxy/pull/1788) Update base docker image to alpine 3.16 +- [#1760](https://github.com/oauth2-proxy/oauth2-proxy/pull/1760) Option to configure API routes + + # V7.3.0 ## Release Highlights diff --git a/contrib/oauth2-proxy.cfg.example b/contrib/oauth2-proxy.cfg.example index cb0e8dcb..216c9c9d 100644 --- a/contrib/oauth2-proxy.cfg.example +++ b/contrib/oauth2-proxy.cfg.example @@ -69,6 +69,11 @@ # "^/metrics" # ] +## mark paths as API routes to get HTTP Status code 401 instead of redirect to login page +# api_routes = [ +# "^/api +# ] + ## Templates ## optional directory with custom sign_in.html and error.html # custom_templates_dir = "" diff --git a/docs/docs/configuration/overview.md b/docs/docs/configuration/overview.md index 76121e90..10e0d810 100644 --- a/docs/docs/configuration/overview.md +++ b/docs/docs/configuration/overview.md @@ -75,6 +75,7 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/ | Option | Type | Description | Default | | ------ | ---- | ----------- | ------- | | `--acr-values` | string | optional, see [docs](https://openid.net/specs/openid-connect-eap-acr-values-1_0.html#acrValues) | `""` | +| `--api-route` | string \| list | return HTTP 401 instead of redirecting to authentication server if token is not valid. Format: path_regex | | | `--approval-prompt` | string | OAuth approval_prompt | `"force"` | | `--auth-logging` | bool | Log authentication attempts | true | | `--auth-logging-format` | string | Template for authentication log lines | see [Logging Configuration](#logging-configuration) | diff --git a/oauthproxy.go b/oauthproxy.go index 93a6977f..d11040c3 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -68,6 +68,10 @@ type allowedRoute struct { pathRegex *regexp.Regexp } +type apiRoute struct { + pathRegex *regexp.Regexp +} + // OAuthProxy is the main authentication proxy type OAuthProxy struct { CookieOptions *options.Cookie @@ -76,6 +80,7 @@ type OAuthProxy struct { SignInPath string allowedRoutes []allowedRoute + apiRoutes []apiRoute redirectURL *url.URL // the url to receive requests at whitelistDomains []string provider providers.Provider @@ -176,6 +181,11 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr return nil, err } + apiRoutes, err := buildAPIRoutes(opts) + if err != nil { + return nil, err + } + preAuthChain, err := buildPreAuthChain(opts) if err != nil { return nil, fmt.Errorf("could not build pre-auth chain: %v", err) @@ -202,6 +212,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr provider: provider, sessionStore: sessionStore, redirectURL: redirectURL, + apiRoutes: apiRoutes, allowedRoutes: allowedRoutes, whitelistDomains: opts.WhitelistDomains, skipAuthPreflight: opts.SkipAuthPreflight, @@ -473,6 +484,24 @@ func buildRoutesAllowlist(opts *options.Options) ([]allowedRoute, error) { return routes, nil } +// buildAPIRoutes builds an []apiRoute from ApiRoutes option +func buildAPIRoutes(opts *options.Options) ([]apiRoute, error) { + routes := make([]apiRoute, 0, len(opts.APIRoutes)) + + for _, path := range opts.APIRoutes { + compiledRegex, err := regexp.Compile(path) + if err != nil { + return nil, err + } + logger.Printf("API route - Path: %s", path) + routes = append(routes, apiRoute{ + pathRegex: compiledRegex, + }) + } + + return routes, nil +} + // ClearSessionCookie creates a cookie to unset the user's authentication cookie // stored in the user's session func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) error { @@ -543,6 +572,15 @@ func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool { return false } +func (p *OAuthProxy) isAPIPath(req *http.Request) bool { + for _, route := range p.apiRoutes { + if route.pathRegex.MatchString(req.URL.Path) { + return true + } + } + return false +} + // isTrustedIP is used to check if a request comes from a trusted client IP address. func (p *OAuthProxy) isTrustedIP(req *http.Request) bool { if p.trustedIPs == nil { @@ -911,7 +949,7 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { p.headersChain.Then(p.upstreamProxy).ServeHTTP(rw, req) case ErrNeedsLogin: // we need to send the user to a login screen - if p.forceJSONErrors || isAjax(req) { + if p.forceJSONErrors || isAjax(req) || p.isAPIPath(req) { logger.Printf("No valid authentication in request. Access Denied.") // no point redirecting an AJAX request p.errorJSON(rw, http.StatusUnauthorized) diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 657de44a..10229366 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -1533,7 +1533,8 @@ func (st *SignatureTest) Close() { // fakeNetConn simulates an http.Request.Body buffer that will be consumed // when it is read by the hmacauth.HmacAuth if not handled properly. See: -// https://github.com/18F/hmacauth/pull/4 +// +// https://github.com/18F/hmacauth/pull/4 type fakeNetConn struct { reqBody string } @@ -2421,6 +2422,116 @@ func Test_buildRoutesAllowlist(t *testing.T) { } } +func TestApiRoutes(t *testing.T) { + + ajaxAPIServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + _, err := w.Write([]byte("AJAX API Request")) + if err != nil { + t.Fatal(err) + } + })) + t.Cleanup(ajaxAPIServer.Close) + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + _, err := w.Write([]byte("API Request")) + if err != nil { + t.Fatal(err) + } + })) + t.Cleanup(apiServer.Close) + + uiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + _, err := w.Write([]byte("API Request")) + if err != nil { + t.Fatal(err) + } + })) + t.Cleanup(uiServer.Close) + + opts := baseTestOptions() + opts.UpstreamServers = options.UpstreamConfig{ + Upstreams: []options.Upstream{ + { + ID: apiServer.URL, + Path: "/api", + URI: apiServer.URL, + }, + { + ID: ajaxAPIServer.URL, + Path: "/ajaxapi", + URI: ajaxAPIServer.URL, + }, + { + ID: uiServer.URL, + Path: "/ui", + URI: uiServer.URL, + }, + }, + } + opts.APIRoutes = []string{ + "^/api", + } + opts.SkipProviderButton = true + err := validation.Validate(opts) + assert.NoError(t, err) + proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) + if err != nil { + t.Fatal(err) + } + + testCases := []struct { + name string + contentType string + url string + shouldRedirect bool + }{ + { + name: "AJAX request matching API regex", + contentType: "application/json", + url: "/api/v1/UserInfo", + shouldRedirect: false, + }, + { + name: "AJAX request not matching API regex", + contentType: "application/json", + url: "/ajaxapi/v1/UserInfo", + shouldRedirect: false, + }, + { + name: "Other Request matching API regex", + contentType: "application/grpcwebtext", + url: "/api/v1/UserInfo", + shouldRedirect: false, + }, + { + name: "UI request", + contentType: "html", + url: "/ui/index.html", + shouldRedirect: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req, err := http.NewRequest("GET", tc.url, nil) + req.Header.Set("Accept", tc.contentType) + assert.NoError(t, err) + + rw := httptest.NewRecorder() + proxy.ServeHTTP(rw, req) + + if tc.shouldRedirect { + assert.Equal(t, 302, rw.Code) + } else { + assert.Equal(t, 401, rw.Code) + } + }) + } +} + func TestAllowedRequest(t *testing.T) { upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) diff --git a/pkg/apis/options/options.go b/pkg/apis/options/options.go index a52c6fa5..c65f1244 100644 --- a/pkg/apis/options/options.go +++ b/pkg/apis/options/options.go @@ -50,6 +50,7 @@ type Options struct { Providers Providers `cfg:",internal"` + APIRoutes []string `flag:"api-route" cfg:"api_routes"` SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"` SkipAuthRoutes []string `flag:"skip-auth-route" cfg:"skip_auth_routes"` SkipJwtBearerTokens bool `flag:"skip-jwt-bearer-tokens" cfg:"skip_jwt_bearer_tokens"` @@ -116,6 +117,7 @@ func NewFlagSet() *pflag.FlagSet { flagSet.String("redirect-url", "", "the OAuth Redirect URL. ie: \"https://internalapp.yourcompany.com/oauth2/callback\"") flagSet.StringSlice("skip-auth-regex", []string{}, "(DEPRECATED for --skip-auth-route) bypass authentication for requests path's that match (may be given multiple times)") flagSet.StringSlice("skip-auth-route", []string{}, "bypass authentication for requests that match the method & path. Format: method=path_regex OR method!=path_regex. For all methods: path_regex OR !=path_regex") + flagSet.StringSlice("api-route", []string{}, "return HTTP 401 instead of redirecting to authentication server if token is not valid. Format: path_regex") flagSet.Bool("skip-provider-button", false, "will skip sign-in-page to directly reach the next step: oauth/start") flagSet.Bool("skip-auth-preflight", false, "will skip authentication for OPTIONS requests") flagSet.Bool("ssl-insecure-skip-verify", false, "skip validation of certificates presented when using HTTPS providers") diff --git a/pkg/validation/allowlist.go b/pkg/validation/allowlist.go index 56a3fd4c..7a36027a 100644 --- a/pkg/validation/allowlist.go +++ b/pkg/validation/allowlist.go @@ -13,8 +13,8 @@ import ( func validateAllowlists(o *options.Options) []string { msgs := []string{} - msgs = append(msgs, validateRoutes(o)...) - msgs = append(msgs, validateRegexes(o)...) + msgs = append(msgs, validateAuthRoutes(o)...) + msgs = append(msgs, validateAuthRegexes(o)...) msgs = append(msgs, validateTrustedIPs(o)...) if len(o.TrustedIPs) > 0 && o.ReverseProxy { @@ -27,8 +27,8 @@ func validateAllowlists(o *options.Options) []string { return msgs } -// validateRoutes validates method=path routes passed with options.SkipAuthRoutes -func validateRoutes(o *options.Options) []string { +// validateAuthRoutes validates method=path routes passed with options.SkipAuthRoutes +func validateAuthRoutes(o *options.Options) []string { msgs := []string{} for _, route := range o.SkipAuthRoutes { var regex string @@ -47,15 +47,8 @@ func validateRoutes(o *options.Options) []string { } // validateRegex validates regex paths passed with options.SkipAuthRegex -func validateRegexes(o *options.Options) []string { - msgs := []string{} - for _, regex := range o.SkipAuthRegex { - _, err := regexp.Compile(regex) - if err != nil { - msgs = append(msgs, fmt.Sprintf("error compiling regex /%s/: %v", regex, err)) - } - } - return msgs +func validateAuthRegexes(o *options.Options) []string { + return validateRegexes(o.SkipAuthRegex) } // validateTrustedIPs validates IP/CIDRs for IP based allowlists @@ -68,3 +61,20 @@ func validateTrustedIPs(o *options.Options) []string { } return msgs } + +// validateAPIRoutes validates regex paths passed with options.ApiRoutes +func validateAPIRoutes(o *options.Options) []string { + return validateRegexes(o.APIRoutes) +} + +// validateRegexes validates all regexes and returns a list of messages in case of error +func validateRegexes(regexes []string) []string { + msgs := []string{} + for _, regex := range regexes { + _, err := regexp.Compile(regex) + if err != nil { + msgs = append(msgs, fmt.Sprintf("error compiling regex /%s/: %v", regex, err)) + } + } + return msgs +} diff --git a/pkg/validation/allowlist_test.go b/pkg/validation/allowlist_test.go index 4600a718..1519493b 100644 --- a/pkg/validation/allowlist_test.go +++ b/pkg/validation/allowlist_test.go @@ -29,7 +29,7 @@ var _ = Describe("Allowlist", func() { opts := &options.Options{ SkipAuthRoutes: r.routes, } - Expect(validateRoutes(opts)).To(ConsistOf(r.errStrings)) + Expect(validateAuthRoutes(opts)).To(ConsistOf(r.errStrings)) }, Entry("Valid regex routes", &validateRoutesTableInput{ routes: []string{ @@ -61,7 +61,7 @@ var _ = Describe("Allowlist", func() { opts := &options.Options{ SkipAuthRegex: r.regexes, } - Expect(validateRegexes(opts)).To(ConsistOf(r.errStrings)) + Expect(validateAuthRegexes(opts)).To(ConsistOf(r.errStrings)) }, Entry("Valid regex routes", &validateRegexesTableInput{ regexes: []string{ diff --git a/pkg/validation/options.go b/pkg/validation/options.go index cd8f24f9..a3ce0518 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -25,6 +25,7 @@ func Validate(o *options.Options) error { msgs = append(msgs, prefixValues("injectRequestHeaders: ", validateHeaders(o.InjectRequestHeaders)...)...) msgs = append(msgs, prefixValues("injectResponseHeaders: ", validateHeaders(o.InjectResponseHeaders)...)...) msgs = append(msgs, validateProviders(o)...) + msgs = append(msgs, validateAPIRoutes(o)...) msgs = configureLogger(o.Logging, msgs) msgs = parseSignatureKey(o, msgs)