1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-05-29 23:17:38 +02:00
This commit is contained in:
Joel Speed 2022-06-03 12:41:30 +01:00
parent 374a676c9d
commit 0dbda5dfac
No known key found for this signature in database
GPG Key ID: 6E80578D6751DEFB
9 changed files with 520 additions and 493 deletions

View File

@ -11,7 +11,6 @@ import (
"net/url" "net/url"
"os" "os"
"os/signal" "os/signal"
"regexp"
"strings" "strings"
"syscall" "syscall"
"time" "time"
@ -61,12 +60,6 @@ var (
ErrAccessDenied = errors.New("access denied") ErrAccessDenied = errors.New("access denied")
) )
// allowedRoute manages method + path based allowlists
type allowedRoute struct {
method string
pathRegex *regexp.Regexp
}
// OAuthProxy is the main authentication proxy // OAuthProxy is the main authentication proxy
type OAuthProxy struct { type OAuthProxy struct {
CookieOptions *options.Cookie CookieOptions *options.Cookie
@ -74,7 +67,6 @@ type OAuthProxy struct {
SignInPath string SignInPath string
allowedRoutes []allowedRoute
redirectURL *url.URL // the url to receive requests at redirectURL *url.URL // the url to receive requests at
whitelistDomains []string whitelistDomains []string
provider providers.Provider provider providers.Provider
@ -83,11 +75,9 @@ type OAuthProxy struct {
basicAuthValidator basic.Validator basicAuthValidator basic.Validator
basicAuthGroups []string basicAuthGroups []string
SkipProviderButton bool SkipProviderButton bool
skipAuthPreflight bool
skipJwtBearerTokens bool skipJwtBearerTokens bool
forceJSONErrors bool forceJSONErrors bool
realClientIPParser ipapi.RealClientIPParser realClientIPParser ipapi.RealClientIPParser
trustedIPs *ip.NetSet
sessionChain alice.Chain sessionChain alice.Chain
headersChain alice.Chain headersChain alice.Chain
@ -161,21 +151,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
logger.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domains:%s path:%s samesite:%s refresh:%s", opts.Cookie.Name, opts.Cookie.Secure, opts.Cookie.HTTPOnly, opts.Cookie.Expire, strings.Join(opts.Cookie.Domains, ","), opts.Cookie.Path, opts.Cookie.SameSite, refresh) logger.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domains:%s path:%s samesite:%s refresh:%s", opts.Cookie.Name, opts.Cookie.Secure, opts.Cookie.HTTPOnly, opts.Cookie.Expire, strings.Join(opts.Cookie.Domains, ","), opts.Cookie.Path, opts.Cookie.SameSite, refresh)
trustedIPs := ip.NewNetSet() preAuthChain, err := buildPreAuthChain(opts, pageWriter)
for _, ipStr := range opts.TrustedIPs {
if ipNet := ip.ParseIPNet(ipStr); ipNet != nil {
trustedIPs.AddIPNet(*ipNet)
} else {
return nil, fmt.Errorf("could not parse IP network (%s)", ipStr)
}
}
allowedRoutes, err := buildRoutesAllowlist(opts)
if err != nil {
return nil, err
}
preAuthChain, err := buildPreAuthChain(opts)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not build pre-auth chain: %v", err) return nil, fmt.Errorf("could not build pre-auth chain: %v", err)
} }
@ -201,14 +177,11 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
provider: provider, provider: provider,
sessionStore: sessionStore, sessionStore: sessionStore,
redirectURL: redirectURL, redirectURL: redirectURL,
allowedRoutes: allowedRoutes,
whitelistDomains: opts.WhitelistDomains, whitelistDomains: opts.WhitelistDomains,
skipAuthPreflight: opts.SkipAuthPreflight,
skipJwtBearerTokens: opts.SkipJwtBearerTokens, skipJwtBearerTokens: opts.SkipJwtBearerTokens,
realClientIPParser: opts.GetRealClientIPParser(), realClientIPParser: opts.GetRealClientIPParser(),
SkipProviderButton: opts.SkipProviderButton, SkipProviderButton: opts.SkipProviderButton,
forceJSONErrors: opts.ForceJSONErrors, forceJSONErrors: opts.ForceJSONErrors,
trustedIPs: trustedIPs,
basicAuthValidator: basicAuthValidator, basicAuthValidator: basicAuthValidator,
basicAuthGroups: opts.HtpasswdUserGroups, basicAuthGroups: opts.HtpasswdUserGroups,
@ -316,7 +289,7 @@ func (p *OAuthProxy) buildProxySubrouter(s *mux.Router) {
// buildPreAuthChain constructs a chain that should process every request before // buildPreAuthChain constructs a chain that should process every request before
// the OAuth2 Proxy authentication logic kicks in. // the OAuth2 Proxy authentication logic kicks in.
// For example forcing HTTPS or health checks. // For example forcing HTTPS or health checks.
func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { func buildPreAuthChain(opts *options.Options, pageWriter pagewriter.Writer) (alice.Chain, error) {
chain := alice.New(middleware.NewScope(opts.ReverseProxy, opts.Logging.RequestIDHeader)) chain := alice.New(middleware.NewScope(opts.ReverseProxy, opts.Logging.RequestIDHeader))
if opts.ForceHTTPS { if opts.ForceHTTPS {
@ -351,6 +324,22 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) {
chain = chain.Append(middleware.NewRequestMetricsWithDefaultRegistry()) chain = chain.Append(middleware.NewRequestMetricsWithDefaultRegistry())
requestAuthorization, err := middleware.NewRequestAuthorization(pageWriter, opts.Authorization.RequestRules, func(req *http.Request) net.IP {
if opts.GetRealClientIPParser() == nil {
host, _ := util.SplitHostPort(req.RemoteAddr)
return net.ParseIP(host)
}
ip, err := opts.GetRealClientIPParser().GetRealClientIP(req.Header)
if err != nil {
return nil
}
return ip
})
if err != nil {
return alice.Chain{}, fmt.Errorf("error initialising request authorization middleware: %w", err)
}
chain = chain.Append(requestAuthorization)
return chain, nil return chain, nil
} }
@ -423,53 +412,6 @@ func buildProviderName(p providers.Provider, override string) string {
return p.Data().ProviderName return p.Data().ProviderName
} }
// buildRoutesAllowlist builds an []allowedRoute list from either the legacy
// SkipAuthRegex option (paths only support) or newer SkipAuthRoutes option
// (method=path support)
func buildRoutesAllowlist(opts *options.Options) ([]allowedRoute, error) {
routes := make([]allowedRoute, 0, len(opts.SkipAuthRegex)+len(opts.SkipAuthRoutes))
for _, path := range opts.SkipAuthRegex {
compiledRegex, err := regexp.Compile(path)
if err != nil {
return nil, err
}
logger.Printf("Skipping auth - Method: ALL | Path: %s", path)
routes = append(routes, allowedRoute{
method: "",
pathRegex: compiledRegex,
})
}
for _, methodPath := range opts.SkipAuthRoutes {
var (
method string
path string
)
parts := strings.SplitN(methodPath, "=", 2)
if len(parts) == 1 {
method = ""
path = parts[0]
} else {
method = strings.ToUpper(parts[0])
path = parts[1]
}
compiledRegex, err := regexp.Compile(path)
if err != nil {
return nil, err
}
logger.Printf("Skipping auth - Method: %s | Path: %s", method, path)
routes = append(routes, allowedRoute{
method: method,
pathRegex: compiledRegex,
})
}
return routes, nil
}
// ClearSessionCookie creates a cookie to unset the user's authentication cookie // ClearSessionCookie creates a cookie to unset the user's authentication cookie
// stored in the user's session // stored in the user's session
func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) error { func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) error {
@ -512,38 +454,8 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, req *http.Request, code i
// IsAllowedRequest is used to check if auth should be skipped for this request // IsAllowedRequest is used to check if auth should be skipped for this request
func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool { func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool {
isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS" scope := middlewareapi.GetRequestScope(req)
return isPreflightRequestAllowed || p.isAllowedRoute(req) || p.isTrustedIP(req) return scope.Authorization.Policy == middlewareapi.AllowPolicy
}
// IsAllowedRoute is used to check if the request method & path is allowed without auth
func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool {
for _, route := range p.allowedRoutes {
if (route.method == "" || req.Method == route.method) && route.pathRegex.MatchString(req.URL.Path) {
return true
}
}
return false
}
// isTrustedIP is used to check if a request comes from a trusted client IP address.
func (p *OAuthProxy) isTrustedIP(req *http.Request) bool {
if p.trustedIPs == nil {
return false
}
remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req)
if err != nil {
logger.Errorf("Error obtaining real IP for trusted IP list: %v", err)
// Possibly spoofed X-Real-IP header
return false
}
if remoteAddr == nil {
return false
}
return p.trustedIPs.Has(remoteAddr)
} }
// SignInPage writes the sign in template to the response // SignInPage writes the sign in template to the response

View File

@ -1358,7 +1358,13 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) {
}, },
}, },
} }
opts.SkipAuthPreflight = true opts.Authorization.RequestRules = []options.AuthorizationRule{
{
ID: "skip-auth-preflight",
Methods: []string{http.MethodOptions},
Policy: options.AllowPolicy,
},
}
err := validation.Validate(opts) err := validation.Validate(opts)
assert.NoError(t, err) assert.NoError(t, err)
@ -1889,7 +1895,14 @@ func Test_noCacheHeaders(t *testing.T) {
}, },
}, },
} }
opts.SkipAuthRegex = []string{".*"} opts.Authorization.RequestRules = []options.AuthorizationRule{
{
ID: "wildcard",
Path: ".*",
Policy: options.AllowPolicy,
},
}
err := validation.Validate(opts) err := validation.Validate(opts)
assert.NoError(t, err) assert.NoError(t, err)
proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
@ -2161,7 +2174,16 @@ func TestTrustedIPs(t *testing.T) {
}, },
}, },
} }
opts.TrustedIPs = tt.trustedIPs if len(tt.trustedIPs) > 0 {
opts.Authorization.RequestRules = []options.AuthorizationRule{
{
ID: "trusted-ips",
IPs: tt.trustedIPs,
Policy: options.AllowPolicy,
},
}
}
opts.ReverseProxy = tt.reverseProxy opts.ReverseProxy = tt.reverseProxy
opts.RealClientIPHeader = tt.realClientIPHeader opts.RealClientIPHeader = tt.realClientIPHeader
err := validation.Validate(opts) err := validation.Validate(opts)
@ -2181,160 +2203,160 @@ func TestTrustedIPs(t *testing.T) {
} }
} }
func Test_buildRoutesAllowlist(t *testing.T) { // func Test_buildRoutesAllowlist(t *testing.T) {
type expectedAllowedRoute struct { // type expectedAllowedRoute struct {
method string // method string
regexString string // regexString string
} // }
//
testCases := []struct { // testCases := []struct {
name string // name string
skipAuthRegex []string // skipAuthRegex []string
skipAuthRoutes []string // skipAuthRoutes []string
expectedRoutes []expectedAllowedRoute // expectedRoutes []expectedAllowedRoute
shouldError bool // shouldError bool
}{ // }{
{ // {
name: "No skip auth configured", // name: "No skip auth configured",
skipAuthRegex: []string{}, // skipAuthRegex: []string{},
skipAuthRoutes: []string{}, // skipAuthRoutes: []string{},
expectedRoutes: []expectedAllowedRoute{}, // expectedRoutes: []expectedAllowedRoute{},
shouldError: false, // shouldError: false,
}, // },
{ // {
name: "Only skipAuthRegex configured", // name: "Only skipAuthRegex configured",
skipAuthRegex: []string{ // skipAuthRegex: []string{
"^/foo/bar", // "^/foo/bar",
"^/baz/[0-9]+/thing", // "^/baz/[0-9]+/thing",
}, // },
skipAuthRoutes: []string{}, // skipAuthRoutes: []string{},
expectedRoutes: []expectedAllowedRoute{ // expectedRoutes: []expectedAllowedRoute{
{ // {
method: "", // method: "",
regexString: "^/foo/bar", // regexString: "^/foo/bar",
}, // },
{ // {
method: "", // method: "",
regexString: "^/baz/[0-9]+/thing", // regexString: "^/baz/[0-9]+/thing",
}, // },
}, // },
shouldError: false, // shouldError: false,
}, // },
{ // {
name: "Only skipAuthRoutes configured", // name: "Only skipAuthRoutes configured",
skipAuthRegex: []string{}, // skipAuthRegex: []string{},
skipAuthRoutes: []string{ // skipAuthRoutes: []string{
"GET=^/foo/bar", // "GET=^/foo/bar",
"POST=^/baz/[0-9]+/thing", // "POST=^/baz/[0-9]+/thing",
"^/all/methods$", // "^/all/methods$",
"WEIRD=^/methods/are/allowed", // "WEIRD=^/methods/are/allowed",
"PATCH=/second/equals?are=handled&just=fine", // "PATCH=/second/equals?are=handled&just=fine",
}, // },
expectedRoutes: []expectedAllowedRoute{ // expectedRoutes: []expectedAllowedRoute{
{ // {
method: "GET", // method: "GET",
regexString: "^/foo/bar", // regexString: "^/foo/bar",
}, // },
{ // {
method: "POST", // method: "POST",
regexString: "^/baz/[0-9]+/thing", // regexString: "^/baz/[0-9]+/thing",
}, // },
{ // {
method: "", // method: "",
regexString: "^/all/methods$", // regexString: "^/all/methods$",
}, // },
{ // {
method: "WEIRD", // method: "WEIRD",
regexString: "^/methods/are/allowed", // regexString: "^/methods/are/allowed",
}, // },
{ // {
method: "PATCH", // method: "PATCH",
regexString: "/second/equals?are=handled&just=fine", // regexString: "/second/equals?are=handled&just=fine",
}, // },
}, // },
shouldError: false, // shouldError: false,
}, // },
{ // {
name: "Both skipAuthRegexes and skipAuthRoutes configured", // name: "Both skipAuthRegexes and skipAuthRoutes configured",
skipAuthRegex: []string{ // skipAuthRegex: []string{
"^/foo/bar/regex", // "^/foo/bar/regex",
"^/baz/[0-9]+/thing/regex", // "^/baz/[0-9]+/thing/regex",
}, // },
skipAuthRoutes: []string{ // skipAuthRoutes: []string{
"GET=^/foo/bar", // "GET=^/foo/bar",
"POST=^/baz/[0-9]+/thing", // "POST=^/baz/[0-9]+/thing",
"^/all/methods$", // "^/all/methods$",
}, // },
expectedRoutes: []expectedAllowedRoute{ // expectedRoutes: []expectedAllowedRoute{
{ // {
method: "", // method: "",
regexString: "^/foo/bar/regex", // regexString: "^/foo/bar/regex",
}, // },
{ // {
method: "", // method: "",
regexString: "^/baz/[0-9]+/thing/regex", // regexString: "^/baz/[0-9]+/thing/regex",
}, // },
{ // {
method: "GET", // method: "GET",
regexString: "^/foo/bar", // regexString: "^/foo/bar",
}, // },
{ // {
method: "POST", // method: "POST",
regexString: "^/baz/[0-9]+/thing", // regexString: "^/baz/[0-9]+/thing",
}, // },
{ // {
method: "", // method: "",
regexString: "^/all/methods$", // regexString: "^/all/methods$",
}, // },
}, // },
shouldError: false, // shouldError: false,
}, // },
{ // {
name: "Invalid skipAuthRegex entry", // name: "Invalid skipAuthRegex entry",
skipAuthRegex: []string{ // skipAuthRegex: []string{
"^/foo/bar", // "^/foo/bar",
"^/baz/[0-9]+/thing", // "^/baz/[0-9]+/thing",
"(bad[regex", // "(bad[regex",
}, // },
skipAuthRoutes: []string{}, // skipAuthRoutes: []string{},
expectedRoutes: []expectedAllowedRoute{}, // expectedRoutes: []expectedAllowedRoute{},
shouldError: true, // shouldError: true,
}, // },
{ // {
name: "Invalid skipAuthRoutes entry", // name: "Invalid skipAuthRoutes entry",
skipAuthRegex: []string{}, // skipAuthRegex: []string{},
skipAuthRoutes: []string{ // skipAuthRoutes: []string{
"GET=^/foo/bar", // "GET=^/foo/bar",
"POST=^/baz/[0-9]+/thing", // "POST=^/baz/[0-9]+/thing",
"^/all/methods$", // "^/all/methods$",
"PUT=(bad[regex", // "PUT=(bad[regex",
}, // },
expectedRoutes: []expectedAllowedRoute{}, // expectedRoutes: []expectedAllowedRoute{},
shouldError: true, // shouldError: true,
}, // },
} // }
//
for _, tc := range testCases { // for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { // t.Run(tc.name, func(t *testing.T) {
opts := &options.Options{ // opts := &options.Options{
SkipAuthRegex: tc.skipAuthRegex, // SkipAuthRegex: tc.skipAuthRegex,
SkipAuthRoutes: tc.skipAuthRoutes, // SkipAuthRoutes: tc.skipAuthRoutes,
} // }
routes, err := buildRoutesAllowlist(opts) // routes, err := buildRoutesAllowlist(opts)
if tc.shouldError { // if tc.shouldError {
assert.Error(t, err) // assert.Error(t, err)
return // return
} // }
assert.NoError(t, err) // assert.NoError(t, err)
//
for i, route := range routes { // for i, route := range routes {
assert.Greater(t, len(tc.expectedRoutes), i) // assert.Greater(t, len(tc.expectedRoutes), i)
assert.Equal(t, route.method, tc.expectedRoutes[i].method) // assert.Equal(t, route.method, tc.expectedRoutes[i].method)
assert.Equal(t, route.pathRegex.String(), tc.expectedRoutes[i].regexString) // assert.Equal(t, route.pathRegex.String(), tc.expectedRoutes[i].regexString)
} // }
}) // })
} // }
} // }
func TestAllowedRequest(t *testing.T) { func TestAllowedRequest(t *testing.T) {
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -2356,12 +2378,20 @@ func TestAllowedRequest(t *testing.T) {
}, },
}, },
} }
opts.SkipAuthRegex = []string{ opts.Authorization.RequestRules = []options.AuthorizationRule{
"^/skip/auth/regex$", {
} ID: "regex",
opts.SkipAuthRoutes = []string{ Path: "^/skip/auth/regex$",
"GET=^/skip/auth/routes/get", Policy: options.AllowPolicy,
},
{
ID: "route",
Path: "^/skip/auth/routes/get",
Methods: []string{http.MethodGet},
Policy: options.AllowPolicy,
},
} }
err := validation.Validate(opts) err := validation.Validate(opts)
assert.NoError(t, err) assert.NoError(t, err)
proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
@ -2417,7 +2447,6 @@ func TestAllowedRequest(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
req, err := http.NewRequest(tc.method, tc.url, nil) req, err := http.NewRequest(tc.method, tc.url, nil)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, tc.allowed, proxy.isAllowedRoute(req))
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
proxy.ServeHTTP(rw, req) proxy.ServeHTTP(rw, req)
@ -2670,8 +2699,18 @@ func TestAuthOnlyAllowedGroupsWithSkipMethods(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
test, err := NewAuthOnlyEndpointTest("?allowed_groups=a,b", func(opts *options.Options) { test, err := NewAuthOnlyEndpointTest("?allowed_groups=a,b", func(opts *options.Options) {
opts.SkipAuthPreflight = true opts.Authorization.RequestRules = []options.AuthorizationRule{
opts.TrustedIPs = []string{"1.2.3.4"} {
ID: "skip-auth-preflight",
Methods: []string{http.MethodOptions},
Policy: options.AllowPolicy,
},
{
ID: "trusted-ips",
IPs: []string{"1.2.3.4"},
Policy: options.AllowPolicy,
},
}
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@ -4,32 +4,24 @@ import (
"net" "net"
"net/http" "net/http"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
) )
type AuthorizationPolicy int
const (
NonePolicy AuthorizationPolicy = iota
AllowPolicy
DelegatePolicy
DenyPolicy
)
type RuleSet interface { type RuleSet interface {
MatchesRequest(req *http.Request) AuthorizationPolicy MatchesRequest(req *http.Request) middlewareapi.AuthorizationPolicy
} }
type rule struct { type rule struct {
conditions []condition conditions []condition
policy AuthorizationPolicy policy middlewareapi.AuthorizationPolicy
} }
func (r rule) matches(req *http.Request) AuthorizationPolicy { func (r rule) matches(req *http.Request) middlewareapi.AuthorizationPolicy {
for _, condition := range r.conditions { for _, condition := range r.conditions {
if !condition.matches(req) { if !condition.matches(req) {
// One of the conditions didn't match so this rule does not apply // One of the conditions didn't match so this rule does not apply
return NonePolicy return middlewareapi.OmittedPolicy
} }
} }
// If all conditions match, return the configured rule policy // If all conditions match, return the configured rule policy
@ -60,17 +52,17 @@ func newRule(authRule options.AuthorizationRule, getClientIPFunc func(*http.Requ
conditions = append(conditions, condition) conditions = append(conditions, condition)
} }
var policy AuthorizationPolicy var policy middlewareapi.AuthorizationPolicy
switch authRule.Policy { switch authRule.Policy {
case options.AllowPolicy: case options.AllowPolicy:
policy = AllowPolicy policy = middlewareapi.AllowPolicy
case options.DelegatePolicy: case options.DelegatePolicy:
policy = DelegatePolicy policy = middlewareapi.DelegatePolicy
case options.DenyPolicy: case options.DenyPolicy:
policy = DenyPolicy policy = middlewareapi.DenyPolicy
default: default:
// This shouldn't be the case and should be prevented by validation // This shouldn't be the case and should be prevented by validation
policy = NonePolicy policy = middlewareapi.OmittedPolicy
} }
return rule{ return rule{
@ -83,15 +75,15 @@ type ruleSet struct {
rules []rule rules []rule
} }
func (r ruleSet) MatchesRequest(req *http.Request) AuthorizationPolicy { func (r ruleSet) MatchesRequest(req *http.Request) middlewareapi.AuthorizationPolicy {
for _, rule := range r.rules { for _, rule := range r.rules {
if policy := rule.matches(req); policy != NonePolicy { if policy := rule.matches(req); policy != middlewareapi.OmittedPolicy {
// The rule applies to this request, return its policy // The rule applies to this request, return its policy
return policy return policy
} }
} }
// No rules matched // No rules matched
return NonePolicy return middlewareapi.OmittedPolicy
} }
func NewRuleSet(requestRules []options.AuthorizationRule, getClientIPFunc func(*http.Request) net.IP) (RuleSet, error) { func NewRuleSet(requestRules []options.AuthorizationRule, getClientIPFunc func(*http.Request) net.IP) (RuleSet, error) {

View File

@ -6,10 +6,11 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
) )
var result AuthorizationPolicy var result middlewareapi.AuthorizationPolicy
func benchmarkRuleSetMatches(ruleCount int, b *testing.B) { func benchmarkRuleSetMatches(ruleCount int, b *testing.B) {
rule1 := options.AuthorizationRule{ rule1 := options.AuthorizationRule{
@ -53,10 +54,10 @@ func benchmarkRuleSetMatches(ruleCount int, b *testing.B) {
req := httptest.NewRequest("GET", "/foo/bar/baz", nil) req := httptest.NewRequest("GET", "/foo/bar/baz", nil)
var r AuthorizationPolicy var r middlewareapi.AuthorizationPolicy
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
r = ruleSet.MatchesRequest(req) r = ruleSet.MatchesRequest(req)
if r != NonePolicy { if r != middlewareapi.OmittedPolicy {
b.Fatal("expected policy not to match") b.Fatal("expected policy not to match")
} }
} }

View File

@ -0,0 +1,61 @@
package middleware
import (
"fmt"
"net"
"net/http"
"github.com/justinas/alice"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app/pagewriter"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authorization"
)
func NewRequestAuthorization(writer pagewriter.Writer, requestRules []options.AuthorizationRule, getClientIPFunc func(*http.Request) net.IP) (alice.Constructor, error) {
ruleset, err := authorization.NewRuleSet(requestRules, getClientIPFunc)
if err != nil {
return nil, fmt.Errorf("could not initialise ruleset: %w", err)
}
ra := &requestAuthorizer{
ruleset: ruleset,
writer: writer,
}
return ra.checkRequestAuthorization, nil
}
type requestAuthorizer struct {
ruleset authorization.RuleSet
writer pagewriter.Writer
}
func (r *requestAuthorizer) checkRequestAuthorization(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
scope := middlewareapi.GetRequestScope(req)
// If scope is nil, this will panic.
// A scope should always be injected before this handler is called.
if scope.Authorization.Policy != middlewareapi.OmittedPolicy {
// The request was already authorized, pass to the next handler
next.ServeHTTP(rw, req)
return
}
policy := r.ruleset.MatchesRequest(req)
switch policy {
case middlewareapi.AllowPolicy, middlewareapi.DelegatePolicy:
scope.Authorization.Type = middlewareapi.RequestAuthorization
scope.Authorization.Policy = policy
case middlewareapi.DenyPolicy:
r.writer.WriteErrorPage(rw, pagewriter.ErrorPageOpts{
Status: http.StatusForbidden,
RequestID: scope.RequestID,
AppError: "Request denied by authorization policy",
Messages: []interface{}{"Request denied by authorization policy"},
})
}
next.ServeHTTP(rw, req)
})
}

View File

@ -1,70 +0,0 @@
package validation
import (
"fmt"
"os"
"regexp"
"strings"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip"
)
func validateAllowlists(o *options.Options) []string {
msgs := []string{}
msgs = append(msgs, validateRoutes(o)...)
msgs = append(msgs, validateRegexes(o)...)
msgs = append(msgs, validateTrustedIPs(o)...)
if len(o.TrustedIPs) > 0 && o.ReverseProxy {
_, err := fmt.Fprintln(os.Stderr, "WARNING: mixing --trusted-ip with --reverse-proxy is a potential security vulnerability. An attacker can inject a trusted IP into an X-Real-IP or X-Forwarded-For header if they aren't properly protected outside of oauth2-proxy")
if err != nil {
panic(err)
}
}
return msgs
}
// validateRoutes validates method=path routes passed with options.SkipAuthRoutes
func validateRoutes(o *options.Options) []string {
msgs := []string{}
for _, route := range o.SkipAuthRoutes {
var regex string
parts := strings.SplitN(route, "=", 2)
if len(parts) == 1 {
regex = parts[0]
} else {
regex = parts[1]
}
_, err := regexp.Compile(regex)
if err != nil {
msgs = append(msgs, fmt.Sprintf("error compiling regex /%s/: %v", regex, err))
}
}
return msgs
}
// validateRegex validates regex paths passed with options.SkipAuthRegex
func validateRegexes(o *options.Options) []string {
msgs := []string{}
for _, regex := range o.SkipAuthRegex {
_, err := regexp.Compile(regex)
if err != nil {
msgs = append(msgs, fmt.Sprintf("error compiling regex /%s/: %v", regex, err))
}
}
return msgs
}
// validateTrustedIPs validates IP/CIDRs for IP based allowlists
func validateTrustedIPs(o *options.Options) []string {
msgs := []string{}
for i, ipStr := range o.TrustedIPs {
if nil == ip.ParseIPNet(ipStr) {
msgs = append(msgs, fmt.Sprintf("trusted_ips[%d] (%s) could not be recognized", i, ipStr))
}
}
return msgs
}

View File

@ -1,125 +1,125 @@
package validation package validation
import ( // import (
. "github.com/onsi/ginkgo" // . "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table" // . "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega" // . "github.com/onsi/gomega"
//
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" // "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
) // )
//
var _ = Describe("Allowlist", func() { // var _ = Describe("Allowlist", func() {
type validateRoutesTableInput struct { // type validateRoutesTableInput struct {
routes []string // routes []string
errStrings []string // errStrings []string
} // }
//
type validateRegexesTableInput struct { // type validateRegexesTableInput struct {
regexes []string // regexes []string
errStrings []string // errStrings []string
} // }
//
type validateTrustedIPsTableInput struct { // type validateTrustedIPsTableInput struct {
trustedIPs []string // trustedIPs []string
errStrings []string // errStrings []string
} // }
//
DescribeTable("validateRoutes", // DescribeTable("validateRoutes",
func(r *validateRoutesTableInput) { // func(r *validateRoutesTableInput) {
opts := &options.Options{ // opts := &options.Options{
SkipAuthRoutes: r.routes, // SkipAuthRoutes: r.routes,
} // }
Expect(validateRoutes(opts)).To(ConsistOf(r.errStrings)) // Expect(validateRoutes(opts)).To(ConsistOf(r.errStrings))
}, // },
Entry("Valid regex routes", &validateRoutesTableInput{ // Entry("Valid regex routes", &validateRoutesTableInput{
routes: []string{ // routes: []string{
"/foo", // "/foo",
"POST=/foo/bar", // "POST=/foo/bar",
"PUT=^/foo/bar$", // "PUT=^/foo/bar$",
"DELETE=/crazy/(?:regex)?/[^/]+/stuff$", // "DELETE=/crazy/(?:regex)?/[^/]+/stuff$",
}, // },
errStrings: []string{}, // errStrings: []string{},
}), // }),
Entry("Bad regexes do not compile", &validateRoutesTableInput{ // Entry("Bad regexes do not compile", &validateRoutesTableInput{
routes: []string{ // routes: []string{
"POST=/(foo", // "POST=/(foo",
"OPTIONS=/foo/bar)", // "OPTIONS=/foo/bar)",
"GET=^]/foo/bar[$", // "GET=^]/foo/bar[$",
"GET=^]/foo/bar[$", // "GET=^]/foo/bar[$",
}, // },
errStrings: []string{ // errStrings: []string{
"error compiling regex //(foo/: error parsing regexp: missing closing ): `/(foo`", // "error compiling regex //(foo/: error parsing regexp: missing closing ): `/(foo`",
"error compiling regex //foo/bar)/: error parsing regexp: unexpected ): `/foo/bar)`", // "error compiling regex //foo/bar)/: error parsing regexp: unexpected ): `/foo/bar)`",
"error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`", // "error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`",
"error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`", // "error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`",
}, // },
}), // }),
) // )
//
DescribeTable("validateRegexes", // DescribeTable("validateRegexes",
func(r *validateRegexesTableInput) { // func(r *validateRegexesTableInput) {
opts := &options.Options{ // opts := &options.Options{
SkipAuthRegex: r.regexes, // SkipAuthRegex: r.regexes,
} // }
Expect(validateRegexes(opts)).To(ConsistOf(r.errStrings)) // Expect(validateRegexes(opts)).To(ConsistOf(r.errStrings))
}, // },
Entry("Valid regex routes", &validateRegexesTableInput{ // Entry("Valid regex routes", &validateRegexesTableInput{
regexes: []string{ // regexes: []string{
"/foo", // "/foo",
"/foo/bar", // "/foo/bar",
"^/foo/bar$", // "^/foo/bar$",
"/crazy/(?:regex)?/[^/]+/stuff$", // "/crazy/(?:regex)?/[^/]+/stuff$",
}, // },
errStrings: []string{}, // errStrings: []string{},
}), // }),
Entry("Bad regexes do not compile", &validateRegexesTableInput{ // Entry("Bad regexes do not compile", &validateRegexesTableInput{
regexes: []string{ // regexes: []string{
"/(foo", // "/(foo",
"/foo/bar)", // "/foo/bar)",
"^]/foo/bar[$", // "^]/foo/bar[$",
"^]/foo/bar[$", // "^]/foo/bar[$",
}, // },
errStrings: []string{ // errStrings: []string{
"error compiling regex //(foo/: error parsing regexp: missing closing ): `/(foo`", // "error compiling regex //(foo/: error parsing regexp: missing closing ): `/(foo`",
"error compiling regex //foo/bar)/: error parsing regexp: unexpected ): `/foo/bar)`", // "error compiling regex //foo/bar)/: error parsing regexp: unexpected ): `/foo/bar)`",
"error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`", // "error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`",
"error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`", // "error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`",
}, // },
}), // }),
) // )
//
DescribeTable("validateTrustedIPs", // DescribeTable("validateTrustedIPs",
func(t *validateTrustedIPsTableInput) { // func(t *validateTrustedIPsTableInput) {
opts := &options.Options{ // opts := &options.Options{
TrustedIPs: t.trustedIPs, // TrustedIPs: t.trustedIPs,
} // }
Expect(validateTrustedIPs(opts)).To(ConsistOf(t.errStrings)) // Expect(validateTrustedIPs(opts)).To(ConsistOf(t.errStrings))
}, // },
Entry("Non-overlapping valid IPs", &validateTrustedIPsTableInput{ // Entry("Non-overlapping valid IPs", &validateTrustedIPsTableInput{
trustedIPs: []string{ // trustedIPs: []string{
"127.0.0.1", // "127.0.0.1",
"10.32.0.1/32", // "10.32.0.1/32",
"43.36.201.0/24", // "43.36.201.0/24",
"::1", // "::1",
"2a12:105:ee7:9234:0:0:0:0/64", // "2a12:105:ee7:9234:0:0:0:0/64",
}, // },
errStrings: []string{}, // errStrings: []string{},
}), // }),
Entry("Overlapping valid IPs", &validateTrustedIPsTableInput{ // Entry("Overlapping valid IPs", &validateTrustedIPsTableInput{
trustedIPs: []string{ // trustedIPs: []string{
"135.180.78.199", // "135.180.78.199",
"135.180.78.199/32", // "135.180.78.199/32",
"d910:a5a1:16f8:ddf5:e5b9:5cef:a65e:41f4", // "d910:a5a1:16f8:ddf5:e5b9:5cef:a65e:41f4",
"d910:a5a1:16f8:ddf5:e5b9:5cef:a65e:41f4/128", // "d910:a5a1:16f8:ddf5:e5b9:5cef:a65e:41f4/128",
}, // },
errStrings: []string{}, // errStrings: []string{},
}), // }),
Entry("Invalid IPs", &validateTrustedIPsTableInput{ // Entry("Invalid IPs", &validateTrustedIPsTableInput{
trustedIPs: []string{"[::1]", "alkwlkbn/32"}, // trustedIPs: []string{"[::1]", "alkwlkbn/32"},
errStrings: []string{ // errStrings: []string{
"trusted_ips[0] ([::1]) could not be recognized", // "trusted_ips[0] ([::1]) could not be recognized",
"trusted_ips[1] (alkwlkbn/32) could not be recognized", // "trusted_ips[1] (alkwlkbn/32) could not be recognized",
}, // },
}), // }),
) // )
}) // })

View File

@ -0,0 +1,94 @@
package validation
import (
"fmt"
"os"
"regexp"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip"
)
func validateAuthorization(authorization options.Authorization, reverseProxy bool) []string {
msgs := []string{}
msgs = append(msgs, validateRequestRules(authorization.RequestRules, reverseProxy)...)
return msgs
}
func validateRequestRules(rules []options.AuthorizationRule, reverseProxy bool) []string {
msgs := []string{}
ids := make(map[string]struct{})
for _, rule := range rules {
msgs = append(msgs, validateRequestRule(ids, rule, reverseProxy)...)
}
return msgs
}
func validateRequestRule(ids map[string]struct{}, rule options.AuthorizationRule, reverseProxy bool) []string {
msgs := []string{}
if rule.ID == "" {
msgs = append(msgs, "request rule has empty ID: IDs are required for all request rules")
}
if _, ok := ids[rule.ID]; ok {
msgs = append(msgs, fmt.Sprintf("multiple request rules found with ID %q: request rule IDs must be unique", rule.ID))
}
ids[rule.ID] = struct{}{}
msgs = append(msgs, validateRequestRulePolicy(rule.ID, rule.Policy)...)
msgs = append(msgs, validateRequestRulePath(rule.ID, rule.Path)...)
msgs = append(msgs, validateRequestRuleIPs(rule.ID, rule.IPs, reverseProxy)...)
return msgs
}
func validateRequestRulePolicy(ruleID string, policy options.AuthorizationPolicy) []string {
msgs := []string{}
switch policy {
case options.AllowPolicy, options.DenyPolicy, options.DelegatePolicy:
// Do nothing for valid options
default:
msgs = append(msgs, fmt.Sprintf("request rule %q has invalid policy (%s): policy must be one of %s, %s or %s", ruleID, policy, options.AllowPolicy, options.DenyPolicy, options.DelegatePolicy))
}
return msgs
}
// validateRequestRulePath validates paths for path/regex based conditions
func validateRequestRulePath(ruleID string, path string) []string {
msgs := []string{}
_, err := regexp.Compile(path)
if err != nil {
msgs = append(msgs, fmt.Sprintf("error compiling path regex (%s) for rule %q: %v", path, ruleID, err))
}
return msgs
}
// validateRequestRuleIPs validates IP/CIDRs for IP based conditions.
func validateRequestRuleIPs(ruleID string, ips []string, reverseProxy bool) []string {
msgs := []string{}
if len(ips) > 0 && reverseProxy {
_, err := fmt.Fprintln(os.Stderr, "WARNING: mixing IP authorization with --reverse-proxy is a potential security vulnerability. An attacker can inject a trusted IP into an X-Real-IP or X-Forwarded-For header if they aren't properly protected outside of oauth2-proxy")
if err != nil {
panic(err)
}
}
for i, ipStr := range ips {
if nil == ip.ParseIPNet(ipStr) {
msgs = append(msgs, fmt.Sprintf("rule %q IP [%d] (%s) could not be recognized", ruleID, i, ipStr))
}
}
return msgs
}

View File

@ -20,6 +20,7 @@ import (
// are of the correct format // are of the correct format
func Validate(o *options.Options) error { func Validate(o *options.Options) error {
msgs := validateCookie(o.Cookie) msgs := validateCookie(o.Cookie)
msgs = append(msgs, validateAuthorization(o.Authorization, o.ReverseProxy)...)
msgs = append(msgs, validateSessionCookieMinimal(o)...) msgs = append(msgs, validateSessionCookieMinimal(o)...)
msgs = append(msgs, validateRedisSessionStore(o)...) msgs = append(msgs, validateRedisSessionStore(o)...)
msgs = append(msgs, prefixValues("injectRequestHeaders: ", validateHeaders(o.InjectRequestHeaders)...)...) msgs = append(msgs, prefixValues("injectRequestHeaders: ", validateHeaders(o.InjectRequestHeaders)...)...)
@ -96,9 +97,6 @@ func Validate(o *options.Options) error {
}) })
} }
// Do this after ReverseProxy validation for TrustedIP coordinated checks
msgs = append(msgs, validateAllowlists(o)...)
if len(msgs) != 0 { if len(msgs) != 0 {
return fmt.Errorf("invalid configuration:\n %s", return fmt.Errorf("invalid configuration:\n %s",
strings.Join(msgs, "\n ")) strings.Join(msgs, "\n "))