From fa6a785eafa29ecb32764c52ca75807f501cf090 Mon Sep 17 00:00:00 2001 From: Nick Meves <nick.meves@greenhouse.io> Date: Sat, 2 Jan 2021 14:20:48 -0800 Subject: [PATCH] Improve handler vs helper organization in oauthproxy.go Additionally, convert a lot of helper methods to be private --- oauthproxy.go | 652 +++++++++++++++++++++++---------------------- oauthproxy_test.go | 9 +- 2 files changed, 332 insertions(+), 329 deletions(-) diff --git a/oauthproxy.go b/oauthproxy.go index a595bc3b..28f667b3 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -31,9 +31,7 @@ import ( ) const ( - httpScheme = "http" - httpsScheme = "https" - + schemeHTTPS = "https" applicationJSON = "application/json" ) @@ -366,49 +364,6 @@ func buildRoutesAllowlist(opts *options.Options) ([]allowedRoute, error) { return routes, nil } -// GetOAuthRedirectURI returns the redirectURL that the upstream OAuth Provider will -// redirect clients to once authenticated -func (p *OAuthProxy) GetOAuthRedirectURI(host string) string { - // default to the request Host if not set - if p.redirectURL.Host != "" { - return p.redirectURL.String() - } - u := *p.redirectURL - if u.Scheme == "" { - if p.CookieSecure { - u.Scheme = httpsScheme - } else { - u.Scheme = httpScheme - } - } - u.Host = host - return u.String() -} - -func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) { - if code == "" { - return nil, providers.ErrMissingCode - } - redirectURI := p.GetOAuthRedirectURI(host) - s, err := p.provider.Redeem(ctx, redirectURI, code) - if err != nil { - return nil, err - } - return s, nil -} - -func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.SessionState) error { - var err error - if s.Email == "" { - s.Email, err = p.provider.GetEmailAddress(ctx, s) - if err != nil && !errors.Is(err, providers.ErrNotImplemented) { - return err - } - } - - return p.provider.EnrichSession(ctx, s) -} - // MakeCSRFCookie creates a cookie for CSRF func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { return p.makeCookie(req, p.CSRFCookieName, value, expiration, now) @@ -466,6 +421,81 @@ func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *s return p.sessionStore.Save(rw, req, s) } +// IsValidRedirect checks whether the redirect URL is whitelisted +func (p *OAuthProxy) IsValidRedirect(redirect string) bool { + switch { + case redirect == "": + // The user didn't specify a redirect, should fallback to `/` + return false + case strings.HasPrefix(redirect, "/") && !strings.HasPrefix(redirect, "//") && !invalidRedirectRegex.MatchString(redirect): + return true + case strings.HasPrefix(redirect, "http://") || strings.HasPrefix(redirect, "https://"): + redirectURL, err := url.Parse(redirect) + if err != nil { + logger.Printf("Rejecting invalid redirect %q: scheme unsupported or missing", redirect) + return false + } + redirectHostname := redirectURL.Hostname() + + for _, domain := range p.whitelistDomains { + domainHostname, domainPort := splitHostPort(strings.TrimLeft(domain, ".")) + if domainHostname == "" { + continue + } + + if (redirectHostname == domainHostname) || (strings.HasPrefix(domain, ".") && strings.HasSuffix(redirectHostname, domainHostname)) { + // the domain names match, now validate the ports + // if the whitelisted domain's port is '*', allow all ports + // if the whitelisted domain contains a specific port, only allow that port + // if the whitelisted domain doesn't contain a port at all, only allow empty redirect ports ie http and https + redirectPort := redirectURL.Port() + if (domainPort == "*") || + (domainPort == redirectPort) || + (domainPort == "" && redirectPort == "") { + return true + } + } + } + + logger.Printf("Rejecting invalid redirect %q: domain / port not in whitelist", redirect) + return false + default: + logger.Printf("Rejecting invalid redirect %q: not an absolute or relative URL", redirect) + return false + } +} + +func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + p.preAuthChain.Then(http.HandlerFunc(p.serveHTTP)).ServeHTTP(rw, req) +} + +func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { + if req.URL.Path != p.AuthOnlyPath && strings.HasPrefix(req.URL.Path, p.ProxyPrefix) { + prepareNoCache(rw) + } + + switch path := req.URL.Path; { + case path == p.RobotsPath: + p.RobotsTxt(rw) + case p.IsAllowedRequest(req): + p.SkipAuthProxy(rw, req) + case path == p.SignInPath: + p.SignIn(rw, req) + case path == p.SignOutPath: + p.SignOut(rw, req) + case path == p.OAuthStartPath: + p.OAuthStart(rw, req) + case path == p.OAuthCallbackPath: + p.OAuthCallback(rw, req) + case path == p.AuthOnlyPath: + p.AuthOnly(rw, req) + case path == p.UserInfoPath: + p.UserInfo(rw, req) + default: + p.Proxy(rw, req) + } +} + // RobotsTxt disallows scraping pages from the OAuthProxy func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) { _, err := fmt.Fprintf(rw, "User-agent: *\nDisallow: /") @@ -496,6 +526,42 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m } } +// IsAllowedRequest is used to check if auth should be skipped for this request +func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool { + isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS" + return isPreflightRequestAllowed || p.isAllowedRoute(req) || p.isTrustedIP(req) +} + +// IsAllowedRoute is used to check if the request method & path is allowed without auth +func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool { + for _, route := range p.allowedRoutes { + if (route.method == "" || req.Method == route.method) && route.pathRegex.MatchString(req.URL.Path) { + return true + } + } + return false +} + +// isTrustedIP is used to check if a request comes from a trusted client IP address. +func (p *OAuthProxy) isTrustedIP(req *http.Request) bool { + if p.trustedIPs == nil { + return false + } + + remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req) + if err != nil { + logger.Errorf("Error obtaining real IP for trusted IP list: %v", err) + // Possibly spoofed X-Real-IP header + return false + } + + if remoteAddr == nil { + return false + } + + return p.trustedIPs.Has(remoteAddr) +} + // SignInPage writes the sing in template to the response func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) { prepareNoCache(rw) @@ -507,7 +573,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code } rw.WriteHeader(code) - redirectURL, err := p.GetAppRedirect(req) + redirectURL, err := p.getAppRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) @@ -566,276 +632,9 @@ func (p *OAuthProxy) ManualSignIn(req *http.Request) (string, bool) { return "", false } -// GetAppRedirect determines the full URL or URI path to redirect clients to -// once authenticated with the OAuthProxy -// Strategy priority (first legal result is used): -// - `rd` querysting parameter -// - `X-Auth-Request-Redirect` header -// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled) -// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*) -// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled) -// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*) -// - `/` -func (p *OAuthProxy) GetAppRedirect(req *http.Request) (string, error) { - err := req.ParseForm() - if err != nil { - return "", err - } - - // These redirect getter functions are strategies ordered by priority - // for figuring out the redirect URL. - type redirectGetter func(req *http.Request) string - for _, rdGetter := range []redirectGetter{ - p.getRdQuerystringRedirect, - p.getXAuthRequestRedirect, - p.getXForwardedHeadersRedirect, - p.getURIRedirect, - } { - if redirect := rdGetter(req); redirect != "" { - return redirect, nil - } - } - - return "/", nil -} - -func isForwardedRequest(req *http.Request) bool { - return requestutil.IsProxied(req) && - req.Host != requestutil.GetRequestHost(req) -} - -func (p *OAuthProxy) hasProxyPrefix(path string) bool { - return strings.HasPrefix(path, fmt.Sprintf("%s/", p.ProxyPrefix)) -} - -// getRdQuerystringRedirect handles this GetAppRedirect strategy: -// - `rd` querysting parameter -func (p *OAuthProxy) getRdQuerystringRedirect(req *http.Request) string { - redirect := req.Form.Get("rd") - if p.IsValidRedirect(redirect) { - return redirect - } - return "" -} - -// getXAuthRequestRedirect handles this GetAppRedirect strategy: -// - `X-Auth-Request-Redirect` Header -func (p *OAuthProxy) getXAuthRequestRedirect(req *http.Request) string { - redirect := req.Header.Get("X-Auth-Request-Redirect") - if p.IsValidRedirect(redirect) { - return redirect - } - return "" -} - -// getXForwardedHeadersRedirect handles these GetAppRedirect strategies: -// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled) -// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*) -func (p *OAuthProxy) getXForwardedHeadersRedirect(req *http.Request) string { - if !isForwardedRequest(req) { - return "" - } - - uri := requestutil.GetRequestURI(req) - if p.hasProxyPrefix(uri) { - uri = "/" - } - - redirect := fmt.Sprintf( - "%s://%s%s", - requestutil.GetRequestProto(req), - requestutil.GetRequestHost(req), - uri, - ) - - if p.IsValidRedirect(redirect) { - return redirect - } - return "" -} - -// getURIRedirect handles these GetAppRedirect strategies: -// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled) -// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*) -// - `/` -func (p *OAuthProxy) getURIRedirect(req *http.Request) string { - redirect := requestutil.GetRequestURI(req) - if !p.IsValidRedirect(redirect) { - redirect = req.URL.RequestURI() - } - - if p.hasProxyPrefix(redirect) { - return "/" - } - return redirect -} - -// splitHostPort separates host and port. If the port is not valid, it returns -// the entire input as host, and it doesn't check the validity of the host. -// Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric. -// *** taken from net/url, modified validOptionalPort() to accept ":*" -func splitHostPort(hostport string) (host, port string) { - host = hostport - - colon := strings.LastIndexByte(host, ':') - if colon != -1 && validOptionalPort(host[colon:]) { - host, port = host[:colon], host[colon+1:] - } - - if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { - host = host[1 : len(host)-1] - } - - return -} - -// validOptionalPort reports whether port is either an empty string -// or matches /^:\d*$/ -// *** taken from net/url, modified to accept ":*" -func validOptionalPort(port string) bool { - if port == "" || port == ":*" { - return true - } - if port[0] != ':' { - return false - } - for _, b := range port[1:] { - if b < '0' || b > '9' { - return false - } - } - return true -} - -// IsValidRedirect checks whether the redirect URL is whitelisted -func (p *OAuthProxy) IsValidRedirect(redirect string) bool { - switch { - case redirect == "": - // The user didn't specify a redirect, should fallback to `/` - return false - case strings.HasPrefix(redirect, "/") && !strings.HasPrefix(redirect, "//") && !invalidRedirectRegex.MatchString(redirect): - return true - case strings.HasPrefix(redirect, "http://") || strings.HasPrefix(redirect, "https://"): - redirectURL, err := url.Parse(redirect) - if err != nil { - logger.Printf("Rejecting invalid redirect %q: scheme unsupported or missing", redirect) - return false - } - redirectHostname := redirectURL.Hostname() - - for _, domain := range p.whitelistDomains { - domainHostname, domainPort := splitHostPort(strings.TrimLeft(domain, ".")) - if domainHostname == "" { - continue - } - - if (redirectHostname == domainHostname) || (strings.HasPrefix(domain, ".") && strings.HasSuffix(redirectHostname, domainHostname)) { - // the domain names match, now validate the ports - // if the whitelisted domain's port is '*', allow all ports - // if the whitelisted domain contains a specific port, only allow that port - // if the whitelisted domain doesn't contain a port at all, only allow empty redirect ports ie http and https - redirectPort := redirectURL.Port() - if (domainPort == "*") || - (domainPort == redirectPort) || - (domainPort == "" && redirectPort == "") { - return true - } - } - } - - logger.Printf("Rejecting invalid redirect %q: domain / port not in whitelist", redirect) - return false - default: - logger.Printf("Rejecting invalid redirect %q: not an absolute or relative URL", redirect) - return false - } -} - -// IsAllowedRequest is used to check if auth should be skipped for this request -func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool { - isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS" - return isPreflightRequestAllowed || p.isAllowedRoute(req) || p.IsTrustedIP(req) -} - -// IsAllowedRoute is used to check if the request method & path is allowed without auth -func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool { - for _, route := range p.allowedRoutes { - if (route.method == "" || req.Method == route.method) && route.pathRegex.MatchString(req.URL.Path) { - return true - } - } - return false -} - -// See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en -var noCacheHeaders = map[string]string{ - "Expires": time.Unix(0, 0).Format(time.RFC1123), - "Cache-Control": "no-cache, no-store, must-revalidate, max-age=0", - "X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/ -} - -// prepareNoCache prepares headers for preventing browser caching. -func prepareNoCache(w http.ResponseWriter) { - // Set NoCache headers - for k, v := range noCacheHeaders { - w.Header().Set(k, v) - } -} - -// IsTrustedIP is used to check if a request comes from a trusted client IP address. -func (p *OAuthProxy) IsTrustedIP(req *http.Request) bool { - if p.trustedIPs == nil { - return false - } - - remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req) - if err != nil { - logger.Errorf("Error obtaining real IP for trusted IP list: %v", err) - // Possibly spoofed X-Real-IP header - return false - } - - if remoteAddr == nil { - return false - } - - return p.trustedIPs.Has(remoteAddr) -} - -func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - p.preAuthChain.Then(http.HandlerFunc(p.serveHTTP)).ServeHTTP(rw, req) -} - -func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { - if req.URL.Path != p.AuthOnlyPath && strings.HasPrefix(req.URL.Path, p.ProxyPrefix) { - prepareNoCache(rw) - } - - switch path := req.URL.Path; { - case path == p.RobotsPath: - p.RobotsTxt(rw) - case p.IsAllowedRequest(req): - p.SkipAuthProxy(rw, req) - case path == p.SignInPath: - p.SignIn(rw, req) - case path == p.SignOutPath: - p.SignOut(rw, req) - case path == p.OAuthStartPath: - p.OAuthStart(rw, req) - case path == p.OAuthCallbackPath: - p.OAuthCallback(rw, req) - case path == p.AuthOnlyPath: - p.AuthOnly(rw, req) - case path == p.UserInfoPath: - p.UserInfo(rw, req) - default: - p.Proxy(rw, req) - } -} - // SignIn serves a page prompting users to sign in func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { - redirect, err := p.GetAppRedirect(req) + redirect, err := p.getAppRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) @@ -893,7 +692,7 @@ func (p *OAuthProxy) UserInfo(rw http.ResponseWriter, req *http.Request) { // SignOut sends a response to clear the authentication cookie func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { - redirect, err := p.GetAppRedirect(req) + redirect, err := p.getAppRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) @@ -918,13 +717,13 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { return } p.SetCSRFCookie(rw, req, nonce) - redirect, err := p.GetAppRedirect(req) + redirect, err := p.getAppRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } - redirectURI := p.GetOAuthRedirectURI(requestutil.GetRequestHost(req)) + redirectURI := p.getOAuthRedirectURI(req) http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound) } @@ -947,7 +746,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { return } - session, err := p.redeemCode(req.Context(), requestutil.GetRequestHost(req), req.Form.Get("code")) + session, err := p.redeemCode(req) if err != nil { logger.Errorf("Error redeeming code during OAuth2 callback: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error") @@ -1006,6 +805,32 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { } } +func (p *OAuthProxy) redeemCode(req *http.Request) (*sessionsapi.SessionState, error) { + code := req.Form.Get("code") + if code == "" { + return nil, providers.ErrMissingCode + } + + redirectURI := p.getOAuthRedirectURI(req) + s, err := p.provider.Redeem(req.Context(), redirectURI, code) + if err != nil { + return nil, err + } + return s, nil +} + +func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.SessionState) error { + var err error + if s.Email == "" { + s.Email, err = p.provider.GetEmailAddress(ctx, s) + if err != nil && !errors.Is(err, providers.ErrNotImplemented) { + return err + } + } + + return p.provider.EnrichSession(ctx, s) +} + // AuthOnly checks whether the user is currently logged in (both authentication // and optional authorization). func (p *OAuthProxy) AuthOnly(rw http.ResponseWriter, req *http.Request) { @@ -1023,7 +848,7 @@ func (p *OAuthProxy) AuthOnly(rw http.ResponseWriter, req *http.Request) { } // we are authenticated - p.addHeadersForProxying(rw, req, session) + p.addHeadersForProxying(rw, session) p.headersChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusAccepted) })).ServeHTTP(rw, req) @@ -1041,13 +866,13 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { switch err { case nil: // we are authenticated - p.addHeadersForProxying(rw, req, session) + p.addHeadersForProxying(rw, session) p.headersChain.Then(p.serveMux).ServeHTTP(rw, req) case ErrNeedsLogin: // we need to send the user to a login screen if isAjax(req) { // no point redirecting an AJAX request - p.ErrorJSON(rw, http.StatusUnauthorized) + p.errorJSON(rw, http.StatusUnauthorized) return } @@ -1066,7 +891,184 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { p.ErrorPage(rw, http.StatusInternalServerError, "Internal Error", "Internal Error") } +} +// See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en +var noCacheHeaders = map[string]string{ + "Expires": time.Unix(0, 0).Format(time.RFC1123), + "Cache-Control": "no-cache, no-store, must-revalidate, max-age=0", + "X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/ +} + +// prepareNoCache prepares headers for preventing browser caching. +func prepareNoCache(w http.ResponseWriter) { + // Set NoCache headers + for k, v := range noCacheHeaders { + w.Header().Set(k, v) + } +} + +// getOAuthRedirectURI returns the redirectURL that the upstream OAuth Provider will +// redirect clients to once authenticated. +// This is usually the OAuthProxy callback URL. +func (p *OAuthProxy) getOAuthRedirectURI(req *http.Request) string { + // if `p.redirectURL` already has a host, return it + if p.redirectURL.Host != "" { + return p.redirectURL.String() + } + + // Otherwise figure out the scheme + host from the request + rd := *p.redirectURL + rd.Host = requestutil.GetRequestHost(req) + rd.Scheme = requestutil.GetRequestProto(req) + + // If CookieSecure is true, return `https` no matter what + // Not all reverse proxies set X-Forwarded-Proto + if p.CookieSecure { + rd.Scheme = schemeHTTPS + } + return rd.String() +} + +// getAppRedirect determines the full URL or URI path to redirect clients to +// once authenticated with the OAuthProxy +// Strategy priority (first legal result is used): +// - `rd` querysting parameter +// - `X-Auth-Request-Redirect` header +// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled) +// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*) +// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled) +// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*) +// - `/` +func (p *OAuthProxy) getAppRedirect(req *http.Request) (string, error) { + err := req.ParseForm() + if err != nil { + return "", err + } + + // These redirect getter functions are strategies ordered by priority + // for figuring out the redirect URL. + type redirectGetter func(req *http.Request) string + for _, rdGetter := range []redirectGetter{ + p.getRdQuerystringRedirect, + p.getXAuthRequestRedirect, + p.getXForwardedHeadersRedirect, + p.getURIRedirect, + } { + if redirect := rdGetter(req); redirect != "" { + return redirect, nil + } + } + + return "/", nil +} + +func isForwardedRequest(req *http.Request) bool { + return requestutil.IsProxied(req) && + req.Host != requestutil.GetRequestHost(req) +} + +func (p *OAuthProxy) hasProxyPrefix(path string) bool { + return strings.HasPrefix(path, fmt.Sprintf("%s/", p.ProxyPrefix)) +} + +// getRdQuerystringRedirect handles this getAppRedirect strategy: +// - `rd` querysting parameter +func (p *OAuthProxy) getRdQuerystringRedirect(req *http.Request) string { + redirect := req.Form.Get("rd") + if p.IsValidRedirect(redirect) { + return redirect + } + return "" +} + +// getXAuthRequestRedirect handles this getAppRedirect strategy: +// - `X-Auth-Request-Redirect` Header +func (p *OAuthProxy) getXAuthRequestRedirect(req *http.Request) string { + redirect := req.Header.Get("X-Auth-Request-Redirect") + if p.IsValidRedirect(redirect) { + return redirect + } + return "" +} + +// getXForwardedHeadersRedirect handles these getAppRedirect strategies: +// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled) +// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*) +func (p *OAuthProxy) getXForwardedHeadersRedirect(req *http.Request) string { + if !isForwardedRequest(req) { + return "" + } + + uri := requestutil.GetRequestURI(req) + if p.hasProxyPrefix(uri) { + uri = "/" + } + + redirect := fmt.Sprintf( + "%s://%s%s", + requestutil.GetRequestProto(req), + requestutil.GetRequestHost(req), + uri, + ) + + if p.IsValidRedirect(redirect) { + return redirect + } + return "" +} + +// getURIRedirect handles these getAppRedirect strategies: +// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled) +// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*) +// - `/` +func (p *OAuthProxy) getURIRedirect(req *http.Request) string { + redirect := requestutil.GetRequestURI(req) + if !p.IsValidRedirect(redirect) { + redirect = req.URL.RequestURI() + } + + if p.hasProxyPrefix(redirect) { + return "/" + } + return redirect +} + +// splitHostPort separates host and port. If the port is not valid, it returns +// the entire input as host, and it doesn't check the validity of the host. +// Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric. +// *** taken from net/url, modified validOptionalPort() to accept ":*" +func splitHostPort(hostport string) (host, port string) { + host = hostport + + colon := strings.LastIndexByte(host, ':') + if colon != -1 && validOptionalPort(host[colon:]) { + host, port = host[:colon], host[colon+1:] + } + + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + host = host[1 : len(host)-1] + } + + return +} + +// validOptionalPort reports whether port is either an empty string +// or matches /^:\d*$/ +// *** taken from net/url, modified to accept ":*" +func validOptionalPort(port string) bool { + if port == "" || port == ":*" { + return true + } + if port[0] != ':' { + return false + } + for _, b := range port[1:] { + if b < '0' || b > '9' { + return false + } + } + return true } // getAuthenticatedSession checks whether a user is authenticated and returns a session object and nil error if so @@ -1153,7 +1155,7 @@ func extractAllowedGroups(req *http.Request) map[string]struct{} { } // addHeadersForProxying adds the appropriate headers the request / response for proxying -func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) { +func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, session *sessionsapi.SessionState) { if session.Email == "" { rw.Header().Set("GAP-Auth", session.User) } else { @@ -1181,8 +1183,8 @@ func isAjax(req *http.Request) bool { return false } -// ErrorJSON returns the error code with an application/json mime type -func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) { +// errorJSON returns the error code with an application/json mime type +func (p *OAuthProxy) errorJSON(rw http.ResponseWriter, code int) { rw.Header().Set("Content-Type", applicationJSON) rw.WriteHeader(code) } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 8adea1ce..3366ef5f 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -415,8 +415,9 @@ func Test_redeemCode(t *testing.T) { t.Fatal(err) } - _, err = proxy.redeemCode(context.Background(), "www.example.com", "") - assert.Error(t, err) + req := httptest.NewRequest(http.MethodGet, "/", nil) + _, err = proxy.redeemCode(req) + assert.Equal(t, providers.ErrMissingCode, err) } func Test_enrichSession(t *testing.T) { @@ -1749,7 +1750,7 @@ func TestRequestSignature(t *testing.T) { } } -func TestGetRedirect(t *testing.T) { +func Test_getAppRedirect(t *testing.T) { opts := baseTestOptions() opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com", ".example.com:8443") err := validation.Validate(opts) @@ -1900,7 +1901,7 @@ func TestGetRedirect(t *testing.T) { req = middleware.AddRequestScope(req, &middleware.RequestScope{ ReverseProxy: tt.reverseProxy, }) - redirect, err := proxy.GetAppRedirect(req) + redirect, err := proxy.getAppRedirect(req) assert.NoError(t, err) assert.Equal(t, tt.expectedRedirect, redirect)