1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2024-11-24 08:52:25 +02:00

Add API route config (#1760)

* Add API route config

In addition to requests with Accept header `application/json` return 401 instead of 302 to login page on requests matching API paths regex.

* Update changelog

* Refactor

* Remove unnecessary comment

* Reorder checks

* Lint Api -> API

Co-authored-by: Sebastian Halder <sebastian.halder@boehringer-ingelheim.com>
This commit is contained in:
Segfault16 2022-09-11 17:09:32 +02:00 committed by GitHub
parent b82593b9cc
commit 965fab422d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 188 additions and 17 deletions

View File

@ -39,6 +39,9 @@ to remain consistent with CLI flags. You should specify `code_challenge_method`
- [#1788](https://github.com/oauth2-proxy/oauth2-proxy/pull/1788) Update base docker image to alpine 3.16 - [#1788](https://github.com/oauth2-proxy/oauth2-proxy/pull/1788) Update base docker image to alpine 3.16
- [#1760](https://github.com/oauth2-proxy/oauth2-proxy/pull/1760) Option to configure API routes
# V7.3.0 # V7.3.0
## Release Highlights ## Release Highlights

View File

@ -69,6 +69,11 @@
# "^/metrics" # "^/metrics"
# ] # ]
## mark paths as API routes to get HTTP Status code 401 instead of redirect to login page
# api_routes = [
# "^/api
# ]
## Templates ## Templates
## optional directory with custom sign_in.html and error.html ## optional directory with custom sign_in.html and error.html
# custom_templates_dir = "" # custom_templates_dir = ""

View File

@ -75,6 +75,7 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/
| Option | Type | Description | Default | | Option | Type | Description | Default |
| ------ | ---- | ----------- | ------- | | ------ | ---- | ----------- | ------- |
| `--acr-values` | string | optional, see [docs](https://openid.net/specs/openid-connect-eap-acr-values-1_0.html#acrValues) | `""` | | `--acr-values` | string | optional, see [docs](https://openid.net/specs/openid-connect-eap-acr-values-1_0.html#acrValues) | `""` |
| `--api-route` | string \| list | return HTTP 401 instead of redirecting to authentication server if token is not valid. Format: path_regex | |
| `--approval-prompt` | string | OAuth approval_prompt | `"force"` | | `--approval-prompt` | string | OAuth approval_prompt | `"force"` |
| `--auth-logging` | bool | Log authentication attempts | true | | `--auth-logging` | bool | Log authentication attempts | true |
| `--auth-logging-format` | string | Template for authentication log lines | see [Logging Configuration](#logging-configuration) | | `--auth-logging-format` | string | Template for authentication log lines | see [Logging Configuration](#logging-configuration) |

View File

@ -68,6 +68,10 @@ type allowedRoute struct {
pathRegex *regexp.Regexp pathRegex *regexp.Regexp
} }
type apiRoute struct {
pathRegex *regexp.Regexp
}
// OAuthProxy is the main authentication proxy // OAuthProxy is the main authentication proxy
type OAuthProxy struct { type OAuthProxy struct {
CookieOptions *options.Cookie CookieOptions *options.Cookie
@ -76,6 +80,7 @@ type OAuthProxy struct {
SignInPath string SignInPath string
allowedRoutes []allowedRoute allowedRoutes []allowedRoute
apiRoutes []apiRoute
redirectURL *url.URL // the url to receive requests at redirectURL *url.URL // the url to receive requests at
whitelistDomains []string whitelistDomains []string
provider providers.Provider provider providers.Provider
@ -176,6 +181,11 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
return nil, err return nil, err
} }
apiRoutes, err := buildAPIRoutes(opts)
if err != nil {
return nil, err
}
preAuthChain, err := buildPreAuthChain(opts) preAuthChain, err := buildPreAuthChain(opts)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not build pre-auth chain: %v", err) return nil, fmt.Errorf("could not build pre-auth chain: %v", err)
@ -202,6 +212,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
provider: provider, provider: provider,
sessionStore: sessionStore, sessionStore: sessionStore,
redirectURL: redirectURL, redirectURL: redirectURL,
apiRoutes: apiRoutes,
allowedRoutes: allowedRoutes, allowedRoutes: allowedRoutes,
whitelistDomains: opts.WhitelistDomains, whitelistDomains: opts.WhitelistDomains,
skipAuthPreflight: opts.SkipAuthPreflight, skipAuthPreflight: opts.SkipAuthPreflight,
@ -473,6 +484,24 @@ func buildRoutesAllowlist(opts *options.Options) ([]allowedRoute, error) {
return routes, nil return routes, nil
} }
// buildAPIRoutes builds an []apiRoute from ApiRoutes option
func buildAPIRoutes(opts *options.Options) ([]apiRoute, error) {
routes := make([]apiRoute, 0, len(opts.APIRoutes))
for _, path := range opts.APIRoutes {
compiledRegex, err := regexp.Compile(path)
if err != nil {
return nil, err
}
logger.Printf("API route - Path: %s", path)
routes = append(routes, apiRoute{
pathRegex: compiledRegex,
})
}
return routes, nil
}
// ClearSessionCookie creates a cookie to unset the user's authentication cookie // ClearSessionCookie creates a cookie to unset the user's authentication cookie
// stored in the user's session // stored in the user's session
func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) error { func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) error {
@ -543,6 +572,15 @@ func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool {
return false return false
} }
func (p *OAuthProxy) isAPIPath(req *http.Request) bool {
for _, route := range p.apiRoutes {
if route.pathRegex.MatchString(req.URL.Path) {
return true
}
}
return false
}
// isTrustedIP is used to check if a request comes from a trusted client IP address. // isTrustedIP is used to check if a request comes from a trusted client IP address.
func (p *OAuthProxy) isTrustedIP(req *http.Request) bool { func (p *OAuthProxy) isTrustedIP(req *http.Request) bool {
if p.trustedIPs == nil { if p.trustedIPs == nil {
@ -911,7 +949,7 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
p.headersChain.Then(p.upstreamProxy).ServeHTTP(rw, req) p.headersChain.Then(p.upstreamProxy).ServeHTTP(rw, req)
case ErrNeedsLogin: case ErrNeedsLogin:
// we need to send the user to a login screen // we need to send the user to a login screen
if p.forceJSONErrors || isAjax(req) { if p.forceJSONErrors || isAjax(req) || p.isAPIPath(req) {
logger.Printf("No valid authentication in request. Access Denied.") logger.Printf("No valid authentication in request. Access Denied.")
// no point redirecting an AJAX request // no point redirecting an AJAX request
p.errorJSON(rw, http.StatusUnauthorized) p.errorJSON(rw, http.StatusUnauthorized)

View File

@ -1533,7 +1533,8 @@ func (st *SignatureTest) Close() {
// fakeNetConn simulates an http.Request.Body buffer that will be consumed // fakeNetConn simulates an http.Request.Body buffer that will be consumed
// when it is read by the hmacauth.HmacAuth if not handled properly. See: // when it is read by the hmacauth.HmacAuth if not handled properly. See:
// https://github.com/18F/hmacauth/pull/4 //
// https://github.com/18F/hmacauth/pull/4
type fakeNetConn struct { type fakeNetConn struct {
reqBody string reqBody string
} }
@ -2421,6 +2422,116 @@ func Test_buildRoutesAllowlist(t *testing.T) {
} }
} }
func TestApiRoutes(t *testing.T) {
ajaxAPIServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
_, err := w.Write([]byte("AJAX API Request"))
if err != nil {
t.Fatal(err)
}
}))
t.Cleanup(ajaxAPIServer.Close)
apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
_, err := w.Write([]byte("API Request"))
if err != nil {
t.Fatal(err)
}
}))
t.Cleanup(apiServer.Close)
uiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
_, err := w.Write([]byte("API Request"))
if err != nil {
t.Fatal(err)
}
}))
t.Cleanup(uiServer.Close)
opts := baseTestOptions()
opts.UpstreamServers = options.UpstreamConfig{
Upstreams: []options.Upstream{
{
ID: apiServer.URL,
Path: "/api",
URI: apiServer.URL,
},
{
ID: ajaxAPIServer.URL,
Path: "/ajaxapi",
URI: ajaxAPIServer.URL,
},
{
ID: uiServer.URL,
Path: "/ui",
URI: uiServer.URL,
},
},
}
opts.APIRoutes = []string{
"^/api",
}
opts.SkipProviderButton = true
err := validation.Validate(opts)
assert.NoError(t, err)
proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
if err != nil {
t.Fatal(err)
}
testCases := []struct {
name string
contentType string
url string
shouldRedirect bool
}{
{
name: "AJAX request matching API regex",
contentType: "application/json",
url: "/api/v1/UserInfo",
shouldRedirect: false,
},
{
name: "AJAX request not matching API regex",
contentType: "application/json",
url: "/ajaxapi/v1/UserInfo",
shouldRedirect: false,
},
{
name: "Other Request matching API regex",
contentType: "application/grpcwebtext",
url: "/api/v1/UserInfo",
shouldRedirect: false,
},
{
name: "UI request",
contentType: "html",
url: "/ui/index.html",
shouldRedirect: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, err := http.NewRequest("GET", tc.url, nil)
req.Header.Set("Accept", tc.contentType)
assert.NoError(t, err)
rw := httptest.NewRecorder()
proxy.ServeHTTP(rw, req)
if tc.shouldRedirect {
assert.Equal(t, 302, rw.Code)
} else {
assert.Equal(t, 401, rw.Code)
}
})
}
}
func TestAllowedRequest(t *testing.T) { func TestAllowedRequest(t *testing.T) {
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200) w.WriteHeader(200)

View File

@ -50,6 +50,7 @@ type Options struct {
Providers Providers `cfg:",internal"` Providers Providers `cfg:",internal"`
APIRoutes []string `flag:"api-route" cfg:"api_routes"`
SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"` SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"`
SkipAuthRoutes []string `flag:"skip-auth-route" cfg:"skip_auth_routes"` SkipAuthRoutes []string `flag:"skip-auth-route" cfg:"skip_auth_routes"`
SkipJwtBearerTokens bool `flag:"skip-jwt-bearer-tokens" cfg:"skip_jwt_bearer_tokens"` SkipJwtBearerTokens bool `flag:"skip-jwt-bearer-tokens" cfg:"skip_jwt_bearer_tokens"`
@ -116,6 +117,7 @@ func NewFlagSet() *pflag.FlagSet {
flagSet.String("redirect-url", "", "the OAuth Redirect URL. ie: \"https://internalapp.yourcompany.com/oauth2/callback\"") flagSet.String("redirect-url", "", "the OAuth Redirect URL. ie: \"https://internalapp.yourcompany.com/oauth2/callback\"")
flagSet.StringSlice("skip-auth-regex", []string{}, "(DEPRECATED for --skip-auth-route) bypass authentication for requests path's that match (may be given multiple times)") flagSet.StringSlice("skip-auth-regex", []string{}, "(DEPRECATED for --skip-auth-route) bypass authentication for requests path's that match (may be given multiple times)")
flagSet.StringSlice("skip-auth-route", []string{}, "bypass authentication for requests that match the method & path. Format: method=path_regex OR method!=path_regex. For all methods: path_regex OR !=path_regex") flagSet.StringSlice("skip-auth-route", []string{}, "bypass authentication for requests that match the method & path. Format: method=path_regex OR method!=path_regex. For all methods: path_regex OR !=path_regex")
flagSet.StringSlice("api-route", []string{}, "return HTTP 401 instead of redirecting to authentication server if token is not valid. Format: path_regex")
flagSet.Bool("skip-provider-button", false, "will skip sign-in-page to directly reach the next step: oauth/start") flagSet.Bool("skip-provider-button", false, "will skip sign-in-page to directly reach the next step: oauth/start")
flagSet.Bool("skip-auth-preflight", false, "will skip authentication for OPTIONS requests") flagSet.Bool("skip-auth-preflight", false, "will skip authentication for OPTIONS requests")
flagSet.Bool("ssl-insecure-skip-verify", false, "skip validation of certificates presented when using HTTPS providers") flagSet.Bool("ssl-insecure-skip-verify", false, "skip validation of certificates presented when using HTTPS providers")

View File

@ -13,8 +13,8 @@ import (
func validateAllowlists(o *options.Options) []string { func validateAllowlists(o *options.Options) []string {
msgs := []string{} msgs := []string{}
msgs = append(msgs, validateRoutes(o)...) msgs = append(msgs, validateAuthRoutes(o)...)
msgs = append(msgs, validateRegexes(o)...) msgs = append(msgs, validateAuthRegexes(o)...)
msgs = append(msgs, validateTrustedIPs(o)...) msgs = append(msgs, validateTrustedIPs(o)...)
if len(o.TrustedIPs) > 0 && o.ReverseProxy { if len(o.TrustedIPs) > 0 && o.ReverseProxy {
@ -27,8 +27,8 @@ func validateAllowlists(o *options.Options) []string {
return msgs return msgs
} }
// validateRoutes validates method=path routes passed with options.SkipAuthRoutes // validateAuthRoutes validates method=path routes passed with options.SkipAuthRoutes
func validateRoutes(o *options.Options) []string { func validateAuthRoutes(o *options.Options) []string {
msgs := []string{} msgs := []string{}
for _, route := range o.SkipAuthRoutes { for _, route := range o.SkipAuthRoutes {
var regex string var regex string
@ -47,15 +47,8 @@ func validateRoutes(o *options.Options) []string {
} }
// validateRegex validates regex paths passed with options.SkipAuthRegex // validateRegex validates regex paths passed with options.SkipAuthRegex
func validateRegexes(o *options.Options) []string { func validateAuthRegexes(o *options.Options) []string {
msgs := []string{} return validateRegexes(o.SkipAuthRegex)
for _, regex := range o.SkipAuthRegex {
_, err := regexp.Compile(regex)
if err != nil {
msgs = append(msgs, fmt.Sprintf("error compiling regex /%s/: %v", regex, err))
}
}
return msgs
} }
// validateTrustedIPs validates IP/CIDRs for IP based allowlists // validateTrustedIPs validates IP/CIDRs for IP based allowlists
@ -68,3 +61,20 @@ func validateTrustedIPs(o *options.Options) []string {
} }
return msgs return msgs
} }
// validateAPIRoutes validates regex paths passed with options.ApiRoutes
func validateAPIRoutes(o *options.Options) []string {
return validateRegexes(o.APIRoutes)
}
// validateRegexes validates all regexes and returns a list of messages in case of error
func validateRegexes(regexes []string) []string {
msgs := []string{}
for _, regex := range regexes {
_, err := regexp.Compile(regex)
if err != nil {
msgs = append(msgs, fmt.Sprintf("error compiling regex /%s/: %v", regex, err))
}
}
return msgs
}

View File

@ -29,7 +29,7 @@ var _ = Describe("Allowlist", func() {
opts := &options.Options{ opts := &options.Options{
SkipAuthRoutes: r.routes, SkipAuthRoutes: r.routes,
} }
Expect(validateRoutes(opts)).To(ConsistOf(r.errStrings)) Expect(validateAuthRoutes(opts)).To(ConsistOf(r.errStrings))
}, },
Entry("Valid regex routes", &validateRoutesTableInput{ Entry("Valid regex routes", &validateRoutesTableInput{
routes: []string{ routes: []string{
@ -61,7 +61,7 @@ var _ = Describe("Allowlist", func() {
opts := &options.Options{ opts := &options.Options{
SkipAuthRegex: r.regexes, SkipAuthRegex: r.regexes,
} }
Expect(validateRegexes(opts)).To(ConsistOf(r.errStrings)) Expect(validateAuthRegexes(opts)).To(ConsistOf(r.errStrings))
}, },
Entry("Valid regex routes", &validateRegexesTableInput{ Entry("Valid regex routes", &validateRegexesTableInput{
regexes: []string{ regexes: []string{

View File

@ -25,6 +25,7 @@ func Validate(o *options.Options) error {
msgs = append(msgs, prefixValues("injectRequestHeaders: ", validateHeaders(o.InjectRequestHeaders)...)...) msgs = append(msgs, prefixValues("injectRequestHeaders: ", validateHeaders(o.InjectRequestHeaders)...)...)
msgs = append(msgs, prefixValues("injectResponseHeaders: ", validateHeaders(o.InjectResponseHeaders)...)...) msgs = append(msgs, prefixValues("injectResponseHeaders: ", validateHeaders(o.InjectResponseHeaders)...)...)
msgs = append(msgs, validateProviders(o)...) msgs = append(msgs, validateProviders(o)...)
msgs = append(msgs, validateAPIRoutes(o)...)
msgs = configureLogger(o.Logging, msgs) msgs = configureLogger(o.Logging, msgs)
msgs = parseSignatureKey(o, msgs) msgs = parseSignatureKey(o, msgs)