From fa6a785eafa29ecb32764c52ca75807f501cf090 Mon Sep 17 00:00:00 2001
From: Nick Meves <nick.meves@greenhouse.io>
Date: Sat, 2 Jan 2021 14:20:48 -0800
Subject: [PATCH] Improve handler vs helper organization in oauthproxy.go

Additionally, convert a lot of helper methods to be private
---
 oauthproxy.go      | 652 +++++++++++++++++++++++----------------------
 oauthproxy_test.go |   9 +-
 2 files changed, 332 insertions(+), 329 deletions(-)

diff --git a/oauthproxy.go b/oauthproxy.go
index a595bc3b..28f667b3 100644
--- a/oauthproxy.go
+++ b/oauthproxy.go
@@ -31,9 +31,7 @@ import (
 )
 
 const (
-	httpScheme  = "http"
-	httpsScheme = "https"
-
+	schemeHTTPS     = "https"
 	applicationJSON = "application/json"
 )
 
@@ -366,49 +364,6 @@ func buildRoutesAllowlist(opts *options.Options) ([]allowedRoute, error) {
 	return routes, nil
 }
 
-// GetOAuthRedirectURI returns the redirectURL that the upstream OAuth Provider will
-// redirect clients to once authenticated
-func (p *OAuthProxy) GetOAuthRedirectURI(host string) string {
-	// default to the request Host if not set
-	if p.redirectURL.Host != "" {
-		return p.redirectURL.String()
-	}
-	u := *p.redirectURL
-	if u.Scheme == "" {
-		if p.CookieSecure {
-			u.Scheme = httpsScheme
-		} else {
-			u.Scheme = httpScheme
-		}
-	}
-	u.Host = host
-	return u.String()
-}
-
-func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) {
-	if code == "" {
-		return nil, providers.ErrMissingCode
-	}
-	redirectURI := p.GetOAuthRedirectURI(host)
-	s, err := p.provider.Redeem(ctx, redirectURI, code)
-	if err != nil {
-		return nil, err
-	}
-	return s, nil
-}
-
-func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.SessionState) error {
-	var err error
-	if s.Email == "" {
-		s.Email, err = p.provider.GetEmailAddress(ctx, s)
-		if err != nil && !errors.Is(err, providers.ErrNotImplemented) {
-			return err
-		}
-	}
-
-	return p.provider.EnrichSession(ctx, s)
-}
-
 // MakeCSRFCookie creates a cookie for CSRF
 func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
 	return p.makeCookie(req, p.CSRFCookieName, value, expiration, now)
@@ -466,6 +421,81 @@ func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *s
 	return p.sessionStore.Save(rw, req, s)
 }
 
+// IsValidRedirect checks whether the redirect URL is whitelisted
+func (p *OAuthProxy) IsValidRedirect(redirect string) bool {
+	switch {
+	case redirect == "":
+		// The user didn't specify a redirect, should fallback to `/`
+		return false
+	case strings.HasPrefix(redirect, "/") && !strings.HasPrefix(redirect, "//") && !invalidRedirectRegex.MatchString(redirect):
+		return true
+	case strings.HasPrefix(redirect, "http://") || strings.HasPrefix(redirect, "https://"):
+		redirectURL, err := url.Parse(redirect)
+		if err != nil {
+			logger.Printf("Rejecting invalid redirect %q: scheme unsupported or missing", redirect)
+			return false
+		}
+		redirectHostname := redirectURL.Hostname()
+
+		for _, domain := range p.whitelistDomains {
+			domainHostname, domainPort := splitHostPort(strings.TrimLeft(domain, "."))
+			if domainHostname == "" {
+				continue
+			}
+
+			if (redirectHostname == domainHostname) || (strings.HasPrefix(domain, ".") && strings.HasSuffix(redirectHostname, domainHostname)) {
+				// the domain names match, now validate the ports
+				// if the whitelisted domain's port is '*', allow all ports
+				// if the whitelisted domain contains a specific port, only allow that port
+				// if the whitelisted domain doesn't contain a port at all, only allow empty redirect ports ie http and https
+				redirectPort := redirectURL.Port()
+				if (domainPort == "*") ||
+					(domainPort == redirectPort) ||
+					(domainPort == "" && redirectPort == "") {
+					return true
+				}
+			}
+		}
+
+		logger.Printf("Rejecting invalid redirect %q: domain / port not in whitelist", redirect)
+		return false
+	default:
+		logger.Printf("Rejecting invalid redirect %q: not an absolute or relative URL", redirect)
+		return false
+	}
+}
+
+func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
+	p.preAuthChain.Then(http.HandlerFunc(p.serveHTTP)).ServeHTTP(rw, req)
+}
+
+func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
+	if req.URL.Path != p.AuthOnlyPath && strings.HasPrefix(req.URL.Path, p.ProxyPrefix) {
+		prepareNoCache(rw)
+	}
+
+	switch path := req.URL.Path; {
+	case path == p.RobotsPath:
+		p.RobotsTxt(rw)
+	case p.IsAllowedRequest(req):
+		p.SkipAuthProxy(rw, req)
+	case path == p.SignInPath:
+		p.SignIn(rw, req)
+	case path == p.SignOutPath:
+		p.SignOut(rw, req)
+	case path == p.OAuthStartPath:
+		p.OAuthStart(rw, req)
+	case path == p.OAuthCallbackPath:
+		p.OAuthCallback(rw, req)
+	case path == p.AuthOnlyPath:
+		p.AuthOnly(rw, req)
+	case path == p.UserInfoPath:
+		p.UserInfo(rw, req)
+	default:
+		p.Proxy(rw, req)
+	}
+}
+
 // RobotsTxt disallows scraping pages from the OAuthProxy
 func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) {
 	_, err := fmt.Fprintf(rw, "User-agent: *\nDisallow: /")
@@ -496,6 +526,42 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m
 	}
 }
 
+// IsAllowedRequest is used to check if auth should be skipped for this request
+func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool {
+	isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS"
+	return isPreflightRequestAllowed || p.isAllowedRoute(req) || p.isTrustedIP(req)
+}
+
+// IsAllowedRoute is used to check if the request method & path is allowed without auth
+func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool {
+	for _, route := range p.allowedRoutes {
+		if (route.method == "" || req.Method == route.method) && route.pathRegex.MatchString(req.URL.Path) {
+			return true
+		}
+	}
+	return false
+}
+
+// isTrustedIP is used to check if a request comes from a trusted client IP address.
+func (p *OAuthProxy) isTrustedIP(req *http.Request) bool {
+	if p.trustedIPs == nil {
+		return false
+	}
+
+	remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req)
+	if err != nil {
+		logger.Errorf("Error obtaining real IP for trusted IP list: %v", err)
+		// Possibly spoofed X-Real-IP header
+		return false
+	}
+
+	if remoteAddr == nil {
+		return false
+	}
+
+	return p.trustedIPs.Has(remoteAddr)
+}
+
 // SignInPage writes the sing in template to the response
 func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) {
 	prepareNoCache(rw)
@@ -507,7 +573,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code
 	}
 	rw.WriteHeader(code)
 
-	redirectURL, err := p.GetAppRedirect(req)
+	redirectURL, err := p.getAppRedirect(req)
 	if err != nil {
 		logger.Errorf("Error obtaining redirect: %v", err)
 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
@@ -566,276 +632,9 @@ func (p *OAuthProxy) ManualSignIn(req *http.Request) (string, bool) {
 	return "", false
 }
 
-// GetAppRedirect determines the full URL or URI path to redirect clients to
-// once authenticated with the OAuthProxy
-// Strategy priority (first legal result is used):
-// - `rd` querysting parameter
-// - `X-Auth-Request-Redirect` header
-// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled)
-// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*)
-// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled)
-// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*)
-// - `/`
-func (p *OAuthProxy) GetAppRedirect(req *http.Request) (string, error) {
-	err := req.ParseForm()
-	if err != nil {
-		return "", err
-	}
-
-	// These redirect getter functions are strategies ordered by priority
-	// for figuring out the redirect URL.
-	type redirectGetter func(req *http.Request) string
-	for _, rdGetter := range []redirectGetter{
-		p.getRdQuerystringRedirect,
-		p.getXAuthRequestRedirect,
-		p.getXForwardedHeadersRedirect,
-		p.getURIRedirect,
-	} {
-		if redirect := rdGetter(req); redirect != "" {
-			return redirect, nil
-		}
-	}
-
-	return "/", nil
-}
-
-func isForwardedRequest(req *http.Request) bool {
-	return requestutil.IsProxied(req) &&
-		req.Host != requestutil.GetRequestHost(req)
-}
-
-func (p *OAuthProxy) hasProxyPrefix(path string) bool {
-	return strings.HasPrefix(path, fmt.Sprintf("%s/", p.ProxyPrefix))
-}
-
-// getRdQuerystringRedirect handles this GetAppRedirect strategy:
-// - `rd` querysting parameter
-func (p *OAuthProxy) getRdQuerystringRedirect(req *http.Request) string {
-	redirect := req.Form.Get("rd")
-	if p.IsValidRedirect(redirect) {
-		return redirect
-	}
-	return ""
-}
-
-// getXAuthRequestRedirect handles this GetAppRedirect strategy:
-// - `X-Auth-Request-Redirect` Header
-func (p *OAuthProxy) getXAuthRequestRedirect(req *http.Request) string {
-	redirect := req.Header.Get("X-Auth-Request-Redirect")
-	if p.IsValidRedirect(redirect) {
-		return redirect
-	}
-	return ""
-}
-
-// getXForwardedHeadersRedirect handles these GetAppRedirect strategies:
-// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled)
-// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*)
-func (p *OAuthProxy) getXForwardedHeadersRedirect(req *http.Request) string {
-	if !isForwardedRequest(req) {
-		return ""
-	}
-
-	uri := requestutil.GetRequestURI(req)
-	if p.hasProxyPrefix(uri) {
-		uri = "/"
-	}
-
-	redirect := fmt.Sprintf(
-		"%s://%s%s",
-		requestutil.GetRequestProto(req),
-		requestutil.GetRequestHost(req),
-		uri,
-	)
-
-	if p.IsValidRedirect(redirect) {
-		return redirect
-	}
-	return ""
-}
-
-// getURIRedirect handles these GetAppRedirect strategies:
-// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled)
-// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*)
-// - `/`
-func (p *OAuthProxy) getURIRedirect(req *http.Request) string {
-	redirect := requestutil.GetRequestURI(req)
-	if !p.IsValidRedirect(redirect) {
-		redirect = req.URL.RequestURI()
-	}
-
-	if p.hasProxyPrefix(redirect) {
-		return "/"
-	}
-	return redirect
-}
-
-// splitHostPort separates host and port. If the port is not valid, it returns
-// the entire input as host, and it doesn't check the validity of the host.
-// Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric.
-// *** taken from net/url, modified validOptionalPort() to accept ":*"
-func splitHostPort(hostport string) (host, port string) {
-	host = hostport
-
-	colon := strings.LastIndexByte(host, ':')
-	if colon != -1 && validOptionalPort(host[colon:]) {
-		host, port = host[:colon], host[colon+1:]
-	}
-
-	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
-		host = host[1 : len(host)-1]
-	}
-
-	return
-}
-
-// validOptionalPort reports whether port is either an empty string
-// or matches /^:\d*$/
-// *** taken from net/url, modified to accept ":*"
-func validOptionalPort(port string) bool {
-	if port == "" || port == ":*" {
-		return true
-	}
-	if port[0] != ':' {
-		return false
-	}
-	for _, b := range port[1:] {
-		if b < '0' || b > '9' {
-			return false
-		}
-	}
-	return true
-}
-
-// IsValidRedirect checks whether the redirect URL is whitelisted
-func (p *OAuthProxy) IsValidRedirect(redirect string) bool {
-	switch {
-	case redirect == "":
-		// The user didn't specify a redirect, should fallback to `/`
-		return false
-	case strings.HasPrefix(redirect, "/") && !strings.HasPrefix(redirect, "//") && !invalidRedirectRegex.MatchString(redirect):
-		return true
-	case strings.HasPrefix(redirect, "http://") || strings.HasPrefix(redirect, "https://"):
-		redirectURL, err := url.Parse(redirect)
-		if err != nil {
-			logger.Printf("Rejecting invalid redirect %q: scheme unsupported or missing", redirect)
-			return false
-		}
-		redirectHostname := redirectURL.Hostname()
-
-		for _, domain := range p.whitelistDomains {
-			domainHostname, domainPort := splitHostPort(strings.TrimLeft(domain, "."))
-			if domainHostname == "" {
-				continue
-			}
-
-			if (redirectHostname == domainHostname) || (strings.HasPrefix(domain, ".") && strings.HasSuffix(redirectHostname, domainHostname)) {
-				// the domain names match, now validate the ports
-				// if the whitelisted domain's port is '*', allow all ports
-				// if the whitelisted domain contains a specific port, only allow that port
-				// if the whitelisted domain doesn't contain a port at all, only allow empty redirect ports ie http and https
-				redirectPort := redirectURL.Port()
-				if (domainPort == "*") ||
-					(domainPort == redirectPort) ||
-					(domainPort == "" && redirectPort == "") {
-					return true
-				}
-			}
-		}
-
-		logger.Printf("Rejecting invalid redirect %q: domain / port not in whitelist", redirect)
-		return false
-	default:
-		logger.Printf("Rejecting invalid redirect %q: not an absolute or relative URL", redirect)
-		return false
-	}
-}
-
-// IsAllowedRequest is used to check if auth should be skipped for this request
-func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool {
-	isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS"
-	return isPreflightRequestAllowed || p.isAllowedRoute(req) || p.IsTrustedIP(req)
-}
-
-// IsAllowedRoute is used to check if the request method & path is allowed without auth
-func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool {
-	for _, route := range p.allowedRoutes {
-		if (route.method == "" || req.Method == route.method) && route.pathRegex.MatchString(req.URL.Path) {
-			return true
-		}
-	}
-	return false
-}
-
-// See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en
-var noCacheHeaders = map[string]string{
-	"Expires":         time.Unix(0, 0).Format(time.RFC1123),
-	"Cache-Control":   "no-cache, no-store, must-revalidate, max-age=0",
-	"X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/
-}
-
-// prepareNoCache prepares headers for preventing browser caching.
-func prepareNoCache(w http.ResponseWriter) {
-	// Set NoCache headers
-	for k, v := range noCacheHeaders {
-		w.Header().Set(k, v)
-	}
-}
-
-// IsTrustedIP is used to check if a request comes from a trusted client IP address.
-func (p *OAuthProxy) IsTrustedIP(req *http.Request) bool {
-	if p.trustedIPs == nil {
-		return false
-	}
-
-	remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req)
-	if err != nil {
-		logger.Errorf("Error obtaining real IP for trusted IP list: %v", err)
-		// Possibly spoofed X-Real-IP header
-		return false
-	}
-
-	if remoteAddr == nil {
-		return false
-	}
-
-	return p.trustedIPs.Has(remoteAddr)
-}
-
-func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
-	p.preAuthChain.Then(http.HandlerFunc(p.serveHTTP)).ServeHTTP(rw, req)
-}
-
-func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
-	if req.URL.Path != p.AuthOnlyPath && strings.HasPrefix(req.URL.Path, p.ProxyPrefix) {
-		prepareNoCache(rw)
-	}
-
-	switch path := req.URL.Path; {
-	case path == p.RobotsPath:
-		p.RobotsTxt(rw)
-	case p.IsAllowedRequest(req):
-		p.SkipAuthProxy(rw, req)
-	case path == p.SignInPath:
-		p.SignIn(rw, req)
-	case path == p.SignOutPath:
-		p.SignOut(rw, req)
-	case path == p.OAuthStartPath:
-		p.OAuthStart(rw, req)
-	case path == p.OAuthCallbackPath:
-		p.OAuthCallback(rw, req)
-	case path == p.AuthOnlyPath:
-		p.AuthOnly(rw, req)
-	case path == p.UserInfoPath:
-		p.UserInfo(rw, req)
-	default:
-		p.Proxy(rw, req)
-	}
-}
-
 // SignIn serves a page prompting users to sign in
 func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
-	redirect, err := p.GetAppRedirect(req)
+	redirect, err := p.getAppRedirect(req)
 	if err != nil {
 		logger.Errorf("Error obtaining redirect: %v", err)
 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
@@ -893,7 +692,7 @@ func (p *OAuthProxy) UserInfo(rw http.ResponseWriter, req *http.Request) {
 
 // SignOut sends a response to clear the authentication cookie
 func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
-	redirect, err := p.GetAppRedirect(req)
+	redirect, err := p.getAppRedirect(req)
 	if err != nil {
 		logger.Errorf("Error obtaining redirect: %v", err)
 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
@@ -918,13 +717,13 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
 		return
 	}
 	p.SetCSRFCookie(rw, req, nonce)
-	redirect, err := p.GetAppRedirect(req)
+	redirect, err := p.getAppRedirect(req)
 	if err != nil {
 		logger.Errorf("Error obtaining redirect: %v", err)
 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
 		return
 	}
-	redirectURI := p.GetOAuthRedirectURI(requestutil.GetRequestHost(req))
+	redirectURI := p.getOAuthRedirectURI(req)
 	http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound)
 }
 
@@ -947,7 +746,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
 		return
 	}
 
-	session, err := p.redeemCode(req.Context(), requestutil.GetRequestHost(req), req.Form.Get("code"))
+	session, err := p.redeemCode(req)
 	if err != nil {
 		logger.Errorf("Error redeeming code during OAuth2 callback: %v", err)
 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error")
@@ -1006,6 +805,32 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
 	}
 }
 
+func (p *OAuthProxy) redeemCode(req *http.Request) (*sessionsapi.SessionState, error) {
+	code := req.Form.Get("code")
+	if code == "" {
+		return nil, providers.ErrMissingCode
+	}
+
+	redirectURI := p.getOAuthRedirectURI(req)
+	s, err := p.provider.Redeem(req.Context(), redirectURI, code)
+	if err != nil {
+		return nil, err
+	}
+	return s, nil
+}
+
+func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.SessionState) error {
+	var err error
+	if s.Email == "" {
+		s.Email, err = p.provider.GetEmailAddress(ctx, s)
+		if err != nil && !errors.Is(err, providers.ErrNotImplemented) {
+			return err
+		}
+	}
+
+	return p.provider.EnrichSession(ctx, s)
+}
+
 // AuthOnly checks whether the user is currently logged in (both authentication
 // and optional authorization).
 func (p *OAuthProxy) AuthOnly(rw http.ResponseWriter, req *http.Request) {
@@ -1023,7 +848,7 @@ func (p *OAuthProxy) AuthOnly(rw http.ResponseWriter, req *http.Request) {
 	}
 
 	// we are authenticated
-	p.addHeadersForProxying(rw, req, session)
+	p.addHeadersForProxying(rw, session)
 	p.headersChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
 		rw.WriteHeader(http.StatusAccepted)
 	})).ServeHTTP(rw, req)
@@ -1041,13 +866,13 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
 	switch err {
 	case nil:
 		// we are authenticated
-		p.addHeadersForProxying(rw, req, session)
+		p.addHeadersForProxying(rw, session)
 		p.headersChain.Then(p.serveMux).ServeHTTP(rw, req)
 	case ErrNeedsLogin:
 		// we need to send the user to a login screen
 		if isAjax(req) {
 			// no point redirecting an AJAX request
-			p.ErrorJSON(rw, http.StatusUnauthorized)
+			p.errorJSON(rw, http.StatusUnauthorized)
 			return
 		}
 
@@ -1066,7 +891,184 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
 		p.ErrorPage(rw, http.StatusInternalServerError,
 			"Internal Error", "Internal Error")
 	}
+}
 
+// See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en
+var noCacheHeaders = map[string]string{
+	"Expires":         time.Unix(0, 0).Format(time.RFC1123),
+	"Cache-Control":   "no-cache, no-store, must-revalidate, max-age=0",
+	"X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/
+}
+
+// prepareNoCache prepares headers for preventing browser caching.
+func prepareNoCache(w http.ResponseWriter) {
+	// Set NoCache headers
+	for k, v := range noCacheHeaders {
+		w.Header().Set(k, v)
+	}
+}
+
+// getOAuthRedirectURI returns the redirectURL that the upstream OAuth Provider will
+// redirect clients to once authenticated.
+// This is usually the OAuthProxy callback URL.
+func (p *OAuthProxy) getOAuthRedirectURI(req *http.Request) string {
+	// if `p.redirectURL` already has a host, return it
+	if p.redirectURL.Host != "" {
+		return p.redirectURL.String()
+	}
+
+	// Otherwise figure out the scheme + host from the request
+	rd := *p.redirectURL
+	rd.Host = requestutil.GetRequestHost(req)
+	rd.Scheme = requestutil.GetRequestProto(req)
+
+	// If CookieSecure is true, return `https` no matter what
+	// Not all reverse proxies set X-Forwarded-Proto
+	if p.CookieSecure {
+		rd.Scheme = schemeHTTPS
+	}
+	return rd.String()
+}
+
+// getAppRedirect determines the full URL or URI path to redirect clients to
+// once authenticated with the OAuthProxy
+// Strategy priority (first legal result is used):
+// - `rd` querysting parameter
+// - `X-Auth-Request-Redirect` header
+// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled)
+// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*)
+// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled)
+// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*)
+// - `/`
+func (p *OAuthProxy) getAppRedirect(req *http.Request) (string, error) {
+	err := req.ParseForm()
+	if err != nil {
+		return "", err
+	}
+
+	// These redirect getter functions are strategies ordered by priority
+	// for figuring out the redirect URL.
+	type redirectGetter func(req *http.Request) string
+	for _, rdGetter := range []redirectGetter{
+		p.getRdQuerystringRedirect,
+		p.getXAuthRequestRedirect,
+		p.getXForwardedHeadersRedirect,
+		p.getURIRedirect,
+	} {
+		if redirect := rdGetter(req); redirect != "" {
+			return redirect, nil
+		}
+	}
+
+	return "/", nil
+}
+
+func isForwardedRequest(req *http.Request) bool {
+	return requestutil.IsProxied(req) &&
+		req.Host != requestutil.GetRequestHost(req)
+}
+
+func (p *OAuthProxy) hasProxyPrefix(path string) bool {
+	return strings.HasPrefix(path, fmt.Sprintf("%s/", p.ProxyPrefix))
+}
+
+// getRdQuerystringRedirect handles this getAppRedirect strategy:
+// - `rd` querysting parameter
+func (p *OAuthProxy) getRdQuerystringRedirect(req *http.Request) string {
+	redirect := req.Form.Get("rd")
+	if p.IsValidRedirect(redirect) {
+		return redirect
+	}
+	return ""
+}
+
+// getXAuthRequestRedirect handles this getAppRedirect strategy:
+// - `X-Auth-Request-Redirect` Header
+func (p *OAuthProxy) getXAuthRequestRedirect(req *http.Request) string {
+	redirect := req.Header.Get("X-Auth-Request-Redirect")
+	if p.IsValidRedirect(redirect) {
+		return redirect
+	}
+	return ""
+}
+
+// getXForwardedHeadersRedirect handles these getAppRedirect strategies:
+// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled)
+// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*)
+func (p *OAuthProxy) getXForwardedHeadersRedirect(req *http.Request) string {
+	if !isForwardedRequest(req) {
+		return ""
+	}
+
+	uri := requestutil.GetRequestURI(req)
+	if p.hasProxyPrefix(uri) {
+		uri = "/"
+	}
+
+	redirect := fmt.Sprintf(
+		"%s://%s%s",
+		requestutil.GetRequestProto(req),
+		requestutil.GetRequestHost(req),
+		uri,
+	)
+
+	if p.IsValidRedirect(redirect) {
+		return redirect
+	}
+	return ""
+}
+
+// getURIRedirect handles these getAppRedirect strategies:
+// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled)
+// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*)
+// - `/`
+func (p *OAuthProxy) getURIRedirect(req *http.Request) string {
+	redirect := requestutil.GetRequestURI(req)
+	if !p.IsValidRedirect(redirect) {
+		redirect = req.URL.RequestURI()
+	}
+
+	if p.hasProxyPrefix(redirect) {
+		return "/"
+	}
+	return redirect
+}
+
+// splitHostPort separates host and port. If the port is not valid, it returns
+// the entire input as host, and it doesn't check the validity of the host.
+// Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric.
+// *** taken from net/url, modified validOptionalPort() to accept ":*"
+func splitHostPort(hostport string) (host, port string) {
+	host = hostport
+
+	colon := strings.LastIndexByte(host, ':')
+	if colon != -1 && validOptionalPort(host[colon:]) {
+		host, port = host[:colon], host[colon+1:]
+	}
+
+	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
+		host = host[1 : len(host)-1]
+	}
+
+	return
+}
+
+// validOptionalPort reports whether port is either an empty string
+// or matches /^:\d*$/
+// *** taken from net/url, modified to accept ":*"
+func validOptionalPort(port string) bool {
+	if port == "" || port == ":*" {
+		return true
+	}
+	if port[0] != ':' {
+		return false
+	}
+	for _, b := range port[1:] {
+		if b < '0' || b > '9' {
+			return false
+		}
+	}
+	return true
 }
 
 // getAuthenticatedSession checks whether a user is authenticated and returns a session object and nil error if so
@@ -1153,7 +1155,7 @@ func extractAllowedGroups(req *http.Request) map[string]struct{} {
 }
 
 // addHeadersForProxying adds the appropriate headers the request / response for proxying
-func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) {
+func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, session *sessionsapi.SessionState) {
 	if session.Email == "" {
 		rw.Header().Set("GAP-Auth", session.User)
 	} else {
@@ -1181,8 +1183,8 @@ func isAjax(req *http.Request) bool {
 	return false
 }
 
-// ErrorJSON returns the error code with an application/json mime type
-func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) {
+// errorJSON returns the error code with an application/json mime type
+func (p *OAuthProxy) errorJSON(rw http.ResponseWriter, code int) {
 	rw.Header().Set("Content-Type", applicationJSON)
 	rw.WriteHeader(code)
 }
diff --git a/oauthproxy_test.go b/oauthproxy_test.go
index 8adea1ce..3366ef5f 100644
--- a/oauthproxy_test.go
+++ b/oauthproxy_test.go
@@ -415,8 +415,9 @@ func Test_redeemCode(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	_, err = proxy.redeemCode(context.Background(), "www.example.com", "")
-	assert.Error(t, err)
+	req := httptest.NewRequest(http.MethodGet, "/", nil)
+	_, err = proxy.redeemCode(req)
+	assert.Equal(t, providers.ErrMissingCode, err)
 }
 
 func Test_enrichSession(t *testing.T) {
@@ -1749,7 +1750,7 @@ func TestRequestSignature(t *testing.T) {
 	}
 }
 
-func TestGetRedirect(t *testing.T) {
+func Test_getAppRedirect(t *testing.T) {
 	opts := baseTestOptions()
 	opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com", ".example.com:8443")
 	err := validation.Validate(opts)
@@ -1900,7 +1901,7 @@ func TestGetRedirect(t *testing.T) {
 			req = middleware.AddRequestScope(req, &middleware.RequestScope{
 				ReverseProxy: tt.reverseProxy,
 			})
-			redirect, err := proxy.GetAppRedirect(req)
+			redirect, err := proxy.getAppRedirect(req)
 
 			assert.NoError(t, err)
 			assert.Equal(t, tt.expectedRedirect, redirect)