diff --git a/CHANGELOG.md b/CHANGELOG.md index bd2b9348..727d7c2a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ - [#432](https://github.com/oauth2-proxy/oauth2-proxy/pull/432) Update ruby dependencies for documentation (@theobarberbany) - [#471](https://github.com/oauth2-proxy/oauth2-proxy/pull/471) Add logging in case of invalid redirects (@gargath) - [#462](https://github.com/oauth2-proxy/oauth2-proxy/pull/462) Allow HTML in banner message (@eritikass). +- [#412](https://github.com/pusher/oauth2_proxy/pull/412) Allow multiple cookie domains to be specified (@edahlseng) - [#413](https://github.com/oauth2-proxy/oauth2-proxy/pull/413) Add -set-basic-auth param to set the Basic Authorization header for upstreams (@morarucostel). # v5.1.0 diff --git a/docs/configuration/configuration.md b/docs/configuration/configuration.md index cb2db72c..0413adec 100644 --- a/docs/configuration/configuration.md +++ b/docs/configuration/configuration.md @@ -33,7 +33,7 @@ An example [oauth2-proxy.cfg]({{ site.gitweb }}/contrib/oauth2-proxy.cfg.example | `-client-secret` | string | the OAuth Client Secret | | | `-client-secret-file` | string | the file with OAuth Client Secret | | | `-config` | string | path to config file | | -| `-cookie-domain` | string | an optional cookie domain to force cookies to (ie: `.yourcompany.com`) | | +| `-cookie-domain` | string \| list | Optional cookie domains to force cookies to (ie: `.yourcompany.com`). The longest domain matching the request's host will be used (or the shortest cookie domain if there is no match). | | | `-cookie-expire` | duration | expire timeframe for cookie | 168h0m0s | | `-cookie-httponly` | bool | set HttpOnly cookie flag | true | | `-cookie-name` | string | the name of the cookie that the oauth_proxy creates | `"_oauth2_proxy"` | diff --git a/main.go b/main.go index bcd6bec5..2d1045eb 100644 --- a/main.go +++ b/main.go @@ -21,6 +21,7 @@ func main() { logger.SetFlags(logger.Lshortfile) flagSet := flag.NewFlagSet("oauth2-proxy", flag.ExitOnError) + cookieDomains := StringArray{} emailDomains := StringArray{} whitelistDomains := StringArray{} upstreams := StringArray{} @@ -87,7 +88,7 @@ func main() { flagSet.String("cookie-name", "_oauth2_proxy", "the name of the cookie that the oauth_proxy creates") flagSet.String("cookie-secret", "", "the seed string for secure cookies (optionally base64 encoded)") - flagSet.String("cookie-domain", "", "an optional cookie domain to force cookies to (ie: .yourcompany.com)*") + flagSet.Var(&cookieDomains, "cookie-domain", "Optional cookie domains to force cookies to (ie: `.yourcompany.com`). The longest domain matching the request's host will be used (or the shortest cookie domain if there is no match).") flagSet.String("cookie-path", "/", "an optional cookie path to force cookies to (ie: /poc/)*") flagSet.Duration("cookie-expire", time.Duration(168)*time.Hour, "expire timeframe for cookie") flagSet.Duration("cookie-refresh", time.Duration(0), "refresh the cookie after this duration; 0 to disable") diff --git a/oauthproxy.go b/oauthproxy.go index dbe379a4..1d49b597 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -64,7 +64,7 @@ type OAuthProxy struct { CookieSeed string CookieName string CSRFCookieName string - CookieDomain string + CookieDomains []string CookiePath string CookieSecure bool CookieHTTPOnly bool @@ -265,13 +265,13 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { refresh = fmt.Sprintf("after %s", opts.CookieRefresh) } - logger.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s path:%s samesite:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHTTPOnly, opts.CookieExpire, opts.CookieDomain, opts.CookiePath, opts.CookieSameSite, refresh) + logger.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domains:%s path:%s samesite:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHTTPOnly, opts.CookieExpire, strings.Join(opts.CookieDomains, ","), opts.CookiePath, opts.CookieSameSite, refresh) return &OAuthProxy{ CookieName: opts.CookieName, CSRFCookieName: fmt.Sprintf("%v_%v", opts.CookieName, "csrf"), CookieSeed: opts.CookieSecret, - CookieDomain: opts.CookieDomain, + CookieDomains: opts.CookieDomains, CookiePath: opts.CookiePath, CookieSecure: opts.CookieSecure, CookieHTTPOnly: opts.CookieHTTPOnly, @@ -377,13 +377,15 @@ func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration } func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie { - if p.CookieDomain != "" { - domain := req.Host + cookieDomain := cookies.GetCookieDomain(req, p.CookieDomains) + + if cookieDomain != "" { + domain := cookies.GetRequestHost(req) if h, _, err := net.SplitHostPort(domain); err == nil { domain = h } - if !strings.HasSuffix(domain, p.CookieDomain) { - logger.Printf("Warning: request host is %q but using configured cookie domain of %q", domain, p.CookieDomain) + if !strings.HasSuffix(domain, cookieDomain) { + logger.Printf("Warning: request host is %q but using configured cookie domain of %q", domain, cookieDomain) } } @@ -391,7 +393,7 @@ func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, ex Name: name, Value: value, Path: p.CookiePath, - Domain: p.CookieDomain, + Domain: cookieDomain, HttpOnly: p.CookieHTTPOnly, Secure: p.CookieSecure, Expires: now.Add(expiration), diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 86cfc90a..87668f10 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -1405,10 +1405,10 @@ func TestAjaxForbiddendRequest(t *testing.T) { func TestClearSplitCookie(t *testing.T) { opts := NewOptions() opts.CookieName = "oauth2" - opts.CookieDomain = "abc" + opts.CookieDomains = []string{"abc"} store, err := cookie.NewCookieSessionStore(&opts.SessionOptions, &opts.CookieOptions) assert.Equal(t, err, nil) - p := OAuthProxy{CookieName: opts.CookieName, CookieDomain: opts.CookieDomain, sessionStore: store} + p := OAuthProxy{CookieName: opts.CookieName, CookieDomains: opts.CookieDomains, sessionStore: store} var rw = httptest.NewRecorder() req := httptest.NewRequest("get", "/", nil) @@ -1434,10 +1434,10 @@ func TestClearSplitCookie(t *testing.T) { func TestClearSingleCookie(t *testing.T) { opts := NewOptions() opts.CookieName = "oauth2" - opts.CookieDomain = "abc" + opts.CookieDomains = []string{"abc"} store, err := cookie.NewCookieSessionStore(&opts.SessionOptions, &opts.CookieOptions) assert.Equal(t, err, nil) - p := OAuthProxy{CookieName: opts.CookieName, CookieDomain: opts.CookieDomain, sessionStore: store} + p := OAuthProxy{CookieName: opts.CookieName, CookieDomains: opts.CookieDomains, sessionStore: store} var rw = httptest.NewRecorder() req := httptest.NewRequest("get", "/", nil) @@ -1530,7 +1530,7 @@ func TestGetJwtSession(t *testing.T) { } func TestFindJwtBearerToken(t *testing.T) { - p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"} + p := OAuthProxy{CookieName: "oauth2", CookieDomains: []string{"abc"}} getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: "example.com"}} validToken := "eyJfoobar.eyJfoobar.12345asdf" diff --git a/options.go b/options.go index abfc6013..1a517b41 100644 --- a/options.go +++ b/options.go @@ -11,6 +11,7 @@ import ( "net/url" "os" "regexp" + "sort" "strings" "time" @@ -402,6 +403,12 @@ func (o *Options) Validate() error { msgs = append(msgs, fmt.Sprintf("cookie_samesite (%s) must be one of ['', 'lax', 'strict', 'none']", o.CookieSameSite)) } + // Sort cookie domains by length, so that we try longer (and more specific) + // domains first + sort.Slice(o.CookieDomains, func(i, j int) bool { + return len(o.CookieDomains[i]) > len(o.CookieDomains[j]) + }) + msgs = parseSignatureKey(o, msgs) msgs = validateCookieName(o, msgs) msgs = setupLogger(o, msgs) diff --git a/pkg/apis/options/cookie.go b/pkg/apis/options/cookie.go index dcb6c75a..4d267731 100644 --- a/pkg/apis/options/cookie.go +++ b/pkg/apis/options/cookie.go @@ -6,7 +6,7 @@ import "time" type CookieOptions struct { CookieName string `flag:"cookie-name" cfg:"cookie_name" env:"OAUTH2_PROXY_COOKIE_NAME"` CookieSecret string `flag:"cookie-secret" cfg:"cookie_secret" env:"OAUTH2_PROXY_COOKIE_SECRET"` - CookieDomain string `flag:"cookie-domain" cfg:"cookie_domain" env:"OAUTH2_PROXY_COOKIE_DOMAIN"` + CookieDomains []string `flag:"cookie-domain" cfg:"cookie_domain" env:"OAUTH2_PROXY_COOKIE_DOMAIN"` CookiePath string `flag:"cookie-path" cfg:"cookie_path" env:"OAUTH2_PROXY_COOKIE_PATH"` CookieExpire time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"` CookieRefresh time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"` diff --git a/pkg/cookies/cookies.go b/pkg/cookies/cookies.go index c4dd1675..7e31464e 100644 --- a/pkg/cookies/cookies.go +++ b/pkg/cookies/cookies.go @@ -39,7 +39,39 @@ func MakeCookie(req *http.Request, name string, value string, path string, domai // MakeCookieFromOptions constructs a cookie based on the given *options.CookieOptions, // value and creation time func MakeCookieFromOptions(req *http.Request, name string, value string, opts *options.CookieOptions, expiration time.Duration, now time.Time) *http.Cookie { - return MakeCookie(req, name, value, opts.CookiePath, opts.CookieDomain, opts.CookieHTTPOnly, opts.CookieSecure, expiration, now, ParseSameSite(opts.CookieSameSite)) + domain := GetCookieDomain(req, opts.CookieDomains) + + if domain != "" { + return MakeCookie(req, name, value, opts.CookiePath, domain, opts.CookieHTTPOnly, opts.CookieSecure, expiration, now, ParseSameSite(opts.CookieSameSite)) + } + // If nothing matches, create the cookie with the shortest domain + logger.Printf("Warning: request host %q did not match any of the specific cookie domains of %q", GetRequestHost(req), strings.Join(opts.CookieDomains, ",")) + defaultDomain := "" + if len(opts.CookieDomains) > 0 { + defaultDomain = opts.CookieDomains[len(opts.CookieDomains)-1] + } + return MakeCookie(req, name, value, opts.CookiePath, defaultDomain, opts.CookieHTTPOnly, opts.CookieSecure, expiration, now, ParseSameSite(opts.CookieSameSite)) +} + +// GetCookieDomain returns the correct cookie domain given a list of domains +// by checking the X-Fowarded-Host and host header of an an http request +func GetCookieDomain(req *http.Request, cookieDomains []string) string { + host := GetRequestHost(req) + for _, domain := range cookieDomains { + if strings.HasSuffix(host, domain) { + return domain + } + } + return "" +} + +// GetRequestHost return the request host header or X-Forwarded-Host if present +func GetRequestHost(req *http.Request) string { + host := req.Header.Get("X-Forwarded-Host") + if host == "" { + host = req.Host + } + return host } // Parse a valid http.SameSite value from a user supplied string for use of making cookies. diff --git a/pkg/sessions/session_store_test.go b/pkg/sessions/session_store_test.go index 0e4d62dd..14707a2c 100644 --- a/pkg/sessions/session_store_test.go +++ b/pkg/sessions/session_store_test.go @@ -63,7 +63,11 @@ var _ = Describe("NewSessionStore", func() { It("have the correct domain set", func() { for _, cookie := range cookies { - Expect(cookie.Domain).To(Equal(cookieOpts.CookieDomain)) + specifiedDomain := "" + if len(cookieOpts.CookieDomains) > 0 { + specifiedDomain = cookieOpts.CookieDomains[0] + } + Expect(cookie.Domain).To(Equal(specifiedDomain)) } }) @@ -343,7 +347,7 @@ var _ = Describe("NewSessionStore", func() { CookieRefresh: time.Duration(2) * time.Hour, CookieSecure: false, CookieHTTPOnly: false, - CookieDomain: "example.com", + CookieDomains: []string{"example.com"}, CookieSameSite: "strict", }