1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-06-15 00:15:00 +02:00

Figure out final app redirect URL with proxy aware request utils

This commit is contained in:
Nick Meves
2021-01-02 13:46:05 -08:00
parent f054682fb7
commit 73fc7706bc
2 changed files with 111 additions and 59 deletions

View File

@ -24,9 +24,9 @@ import (
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/upstream" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/upstream"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
"github.com/oauth2-proxy/oauth2-proxy/v7/providers" "github.com/oauth2-proxy/oauth2-proxy/v7/providers"
) )
@ -98,7 +98,6 @@ type OAuthProxy struct {
SetAuthorization bool SetAuthorization bool
PassAuthorization bool PassAuthorization bool
PreferEmailToUser bool PreferEmailToUser bool
ReverseProxy bool
skipAuthPreflight bool skipAuthPreflight bool
skipJwtBearerTokens bool skipJwtBearerTokens bool
templates *template.Template templates *template.Template
@ -201,7 +200,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
UserInfoPath: fmt.Sprintf("%s/userinfo", opts.ProxyPrefix), UserInfoPath: fmt.Sprintf("%s/userinfo", opts.ProxyPrefix),
ProxyPrefix: opts.ProxyPrefix, ProxyPrefix: opts.ProxyPrefix,
ReverseProxy: opts.ReverseProxy,
provider: opts.GetProvider(), provider: opts.GetProvider(),
providerNameOverride: opts.ProviderName, providerNameOverride: opts.ProviderName,
sessionStore: sessionStore, sessionStore: sessionStore,
@ -231,7 +229,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
// the OAuth2 Proxy authentication logic kicks in. // the OAuth2 Proxy authentication logic kicks in.
// For example forcing HTTPS or health checks. // For example forcing HTTPS or health checks.
func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { func buildPreAuthChain(opts *options.Options) (alice.Chain, error) {
chain := alice.New(middleware.NewScope(opts)) chain := alice.New(middleware.NewScope(opts.ReverseProxy))
if opts.ForceHTTPS { if opts.ForceHTTPS {
_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress) _, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress)
@ -368,9 +366,9 @@ func buildRoutesAllowlist(opts *options.Options) ([]allowedRoute, error) {
return routes, nil return routes, nil
} }
// GetRedirectURI returns the redirectURL that the upstream OAuth Provider will // GetOAuthRedirectURI returns the redirectURL that the upstream OAuth Provider will
// redirect clients to once authenticated // redirect clients to once authenticated
func (p *OAuthProxy) GetRedirectURI(host string) string { func (p *OAuthProxy) GetOAuthRedirectURI(host string) string {
// default to the request Host if not set // default to the request Host if not set
if p.redirectURL.Host != "" { if p.redirectURL.Host != "" {
return p.redirectURL.String() return p.redirectURL.String()
@ -391,7 +389,7 @@ func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessio
if code == "" { if code == "" {
return nil, providers.ErrMissingCode return nil, providers.ErrMissingCode
} }
redirectURI := p.GetRedirectURI(host) redirectURI := p.GetOAuthRedirectURI(host)
s, err := p.provider.Redeem(ctx, redirectURI, code) s, err := p.provider.Redeem(ctx, redirectURI, code)
if err != nil { if err != nil {
return nil, err return nil, err
@ -420,7 +418,7 @@ func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, ex
cookieDomain := cookies.GetCookieDomain(req, p.CookieDomains) cookieDomain := cookies.GetCookieDomain(req, p.CookieDomains)
if cookieDomain != "" { if cookieDomain != "" {
domain := util.GetRequestHost(req) domain := requestutil.GetRequestHost(req)
if h, _, err := net.SplitHostPort(domain); err == nil { if h, _, err := net.SplitHostPort(domain); err == nil {
domain = h domain = h
} }
@ -509,7 +507,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code
} }
rw.WriteHeader(code) rw.WriteHeader(code)
redirectURL, err := p.GetRedirect(req) redirectURL, err := p.GetAppRedirect(req)
if err != nil { if err != nil {
logger.Errorf("Error obtaining redirect: %v", err) logger.Errorf("Error obtaining redirect: %v", err)
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
@ -568,46 +566,108 @@ func (p *OAuthProxy) ManualSignIn(req *http.Request) (string, bool) {
return "", false return "", false
} }
// GetRedirect reads the query parameter to get the URL to redirect clients to // GetAppRedirect determines the full URL or URI path to redirect clients to
// once authenticated with the OAuthProxy // once authenticated with the OAuthProxy
func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) { // Strategy priority (first legal result is used):
err = req.ParseForm() // - `rd` querysting parameter
// - `X-Auth-Request-Redirect` header
// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled)
// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*)
// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled)
// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*)
// - `/`
func (p *OAuthProxy) GetAppRedirect(req *http.Request) (string, error) {
err := req.ParseForm()
if err != nil { if err != nil {
return return "", err
} }
redirect = req.Header.Get("X-Auth-Request-Redirect") // These redirect getter functions are strategies ordered by priority
if req.Form.Get("rd") != "" { // for figuring out the redirect URL.
redirect = req.Form.Get("rd") type redirectGetter func(req *http.Request) string
} for _, rdGetter := range []redirectGetter{
// Quirk: On reverse proxies that doesn't have support for p.getRdQuerystringRedirect,
// "X-Auth-Request-Redirect" header or dynamic header/query string p.getXAuthRequestRedirect,
// manipulation (like Traefik v1 and v2), we can try if the header p.getXForwardedHeadersRedirect,
// X-Forwarded-Host exists or not. p.getURIRedirect,
if redirect == "" && isForwardedRequest(req, p.ReverseProxy) { } {
redirect = p.getRedirectFromForwardHeaders(req) if redirect := rdGetter(req); redirect != "" {
} return redirect, nil
if !p.IsValidRedirect(redirect) {
// Use RequestURI to preserve ?query
redirect = req.URL.RequestURI()
if strings.HasPrefix(redirect, fmt.Sprintf("%s/", p.ProxyPrefix)) {
redirect = "/"
} }
} }
return return "/", nil
} }
// getRedirectFromForwardHeaders returns the redirect URL based on X-Forwarded-{Proto,Host,Uri} headers func isForwardedRequest(req *http.Request) bool {
func (p *OAuthProxy) getRedirectFromForwardHeaders(req *http.Request) string { return requestutil.IsProxied(req) &&
uri := util.GetRequestURI(req) req.Host != requestutil.GetRequestHost(req)
}
if strings.HasPrefix(uri, fmt.Sprintf("%s/", p.ProxyPrefix)) { func (p *OAuthProxy) hasProxyPrefix(path string) bool {
return strings.HasPrefix(path, fmt.Sprintf("%s/", p.ProxyPrefix))
}
// getRdQuerystringRedirect handles this GetAppRedirect strategy:
// - `rd` querysting parameter
func (p *OAuthProxy) getRdQuerystringRedirect(req *http.Request) string {
redirect := req.Form.Get("rd")
if p.IsValidRedirect(redirect) {
return redirect
}
return ""
}
// getXAuthRequestRedirect handles this GetAppRedirect strategy:
// - `X-Auth-Request-Redirect` Header
func (p *OAuthProxy) getXAuthRequestRedirect(req *http.Request) string {
redirect := req.Header.Get("X-Auth-Request-Redirect")
if p.IsValidRedirect(redirect) {
return redirect
}
return ""
}
// getXForwardedHeadersRedirect handles these GetAppRedirect strategies:
// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled)
// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*)
func (p *OAuthProxy) getXForwardedHeadersRedirect(req *http.Request) string {
if !isForwardedRequest(req) {
return ""
}
uri := requestutil.GetRequestURI(req)
if p.hasProxyPrefix(uri) {
uri = "/" uri = "/"
} }
return fmt.Sprintf("%s://%s%s", util.GetRequestProto(req), util.GetRequestHost(req), uri) redirect := fmt.Sprintf(
"%s://%s%s",
requestutil.GetRequestProto(req),
requestutil.GetRequestHost(req),
uri,
)
if p.IsValidRedirect(redirect) {
return redirect
}
return ""
}
// getURIRedirect handles these GetAppRedirect strategies:
// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled)
// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*)
// - `/`
func (p *OAuthProxy) getURIRedirect(req *http.Request) string {
redirect := requestutil.GetRequestURI(req)
if !p.IsValidRedirect(redirect) {
redirect = req.URL.RequestURI()
}
if p.hasProxyPrefix(redirect) {
return "/"
}
return redirect
} }
// splitHostPort separates host and port. If the port is not valid, it returns // splitHostPort separates host and port. If the port is not valid, it returns
@ -707,12 +767,6 @@ func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool {
return false return false
} }
// isForwardedRequest is used to check if X-Forwarded-Host header exists or not
func isForwardedRequest(req *http.Request, reverseProxy bool) bool {
isForwarded := req.Host != util.GetRequestHost(req)
return isForwarded && reverseProxy
}
// See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en // See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en
var noCacheHeaders = map[string]string{ var noCacheHeaders = map[string]string{
"Expires": time.Unix(0, 0).Format(time.RFC1123), "Expires": time.Unix(0, 0).Format(time.RFC1123),
@ -781,7 +835,7 @@ func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
// SignIn serves a page prompting users to sign in // SignIn serves a page prompting users to sign in
func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
redirect, err := p.GetRedirect(req) redirect, err := p.GetAppRedirect(req)
if err != nil { if err != nil {
logger.Errorf("Error obtaining redirect: %v", err) logger.Errorf("Error obtaining redirect: %v", err)
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
@ -839,7 +893,7 @@ func (p *OAuthProxy) UserInfo(rw http.ResponseWriter, req *http.Request) {
// SignOut sends a response to clear the authentication cookie // SignOut sends a response to clear the authentication cookie
func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
redirect, err := p.GetRedirect(req) redirect, err := p.GetAppRedirect(req)
if err != nil { if err != nil {
logger.Errorf("Error obtaining redirect: %v", err) logger.Errorf("Error obtaining redirect: %v", err)
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
@ -864,13 +918,13 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
return return
} }
p.SetCSRFCookie(rw, req, nonce) p.SetCSRFCookie(rw, req, nonce)
redirect, err := p.GetRedirect(req) redirect, err := p.GetAppRedirect(req)
if err != nil { if err != nil {
logger.Errorf("Error obtaining redirect: %v", err) logger.Errorf("Error obtaining redirect: %v", err)
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
return return
} }
redirectURI := p.GetRedirectURI(util.GetRequestHost(req)) redirectURI := p.GetOAuthRedirectURI(requestutil.GetRequestHost(req))
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound) http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound)
} }
@ -893,7 +947,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
return return
} }
session, err := p.redeemCode(req.Context(), util.GetRequestHost(req), req.Form.Get("code")) session, err := p.redeemCode(req.Context(), requestutil.GetRequestHost(req), req.Form.Get("code"))
if err != nil { if err != nil {
logger.Errorf("Error redeeming code during OAuth2 callback: %v", err) logger.Errorf("Error redeeming code during OAuth2 callback: %v", err)
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error") p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error")
@ -1024,7 +1078,7 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R
var session *sessionsapi.SessionState var session *sessionsapi.SessionState
getSession := p.sessionChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { getSession := p.sessionChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
session = middleware.GetRequestScope(req).Session session = middlewareapi.GetRequestScope(req).Session
})) }))
getSession.ServeHTTP(rw, req) getSession.ServeHTTP(rw, req)

View File

@ -19,6 +19,7 @@ import (
"github.com/coreos/go-oidc" "github.com/coreos/go-oidc"
"github.com/mbland/hmacauth" "github.com/mbland/hmacauth"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
@ -1750,8 +1751,7 @@ func TestRequestSignature(t *testing.T) {
func TestGetRedirect(t *testing.T) { func TestGetRedirect(t *testing.T) {
opts := baseTestOptions() opts := baseTestOptions()
opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com") opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com", ".example.com:8443")
opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com:8443")
err := validation.Validate(opts) err := validation.Validate(opts)
assert.NoError(t, err) assert.NoError(t, err)
require.NotEmpty(t, opts.ProxyPrefix) require.NotEmpty(t, opts.ProxyPrefix)
@ -1854,9 +1854,6 @@ func TestGetRedirect(t *testing.T) {
url: "https://oauth.example.com/foo/bar", url: "https://oauth.example.com/foo/bar",
headers: map[string]string{ headers: map[string]string{
"X-Auth-Request-Redirect": "https://a-service.example.com/foo/bar", "X-Auth-Request-Redirect": "https://a-service.example.com/foo/bar",
"X-Forwarded-Proto": "",
"X-Forwarded-Host": "",
"X-Forwarded-Uri": "",
}, },
reverseProxy: true, reverseProxy: true,
expectedRedirect: "https://a-service.example.com/foo/bar", expectedRedirect: "https://a-service.example.com/foo/bar",
@ -1884,10 +1881,9 @@ func TestGetRedirect(t *testing.T) {
name: "proxied request with rd query string and some headers set redirects to proxied URL on rd query string", name: "proxied request with rd query string and some headers set redirects to proxied URL on rd query string",
url: "https://oauth.example.com/foo/bar?rd=https%3A%2F%2Fa%2Dservice%2Eexample%2Ecom%2Ffoo%2Fbaz", url: "https://oauth.example.com/foo/bar?rd=https%3A%2F%2Fa%2Dservice%2Eexample%2Ecom%2Ffoo%2Fbaz",
headers: map[string]string{ headers: map[string]string{
"X-Auth-Request-Redirect": "", "X-Forwarded-Proto": "https",
"X-Forwarded-Proto": "https", "X-Forwarded-Host": "another-service.example.com",
"X-Forwarded-Host": "another-service.example.com", "X-Forwarded-Uri": "/seasons/greetings",
"X-Forwarded-Uri": "/seasons/greetings",
}, },
reverseProxy: true, reverseProxy: true,
expectedRedirect: "https://a-service.example.com/foo/baz", expectedRedirect: "https://a-service.example.com/foo/baz",
@ -1901,8 +1897,10 @@ func TestGetRedirect(t *testing.T) {
req.Header.Add(header, value) req.Header.Add(header, value)
} }
} }
proxy.ReverseProxy = tt.reverseProxy req = middleware.AddRequestScope(req, &middleware.RequestScope{
redirect, err := proxy.GetRedirect(req) ReverseProxy: tt.reverseProxy,
})
redirect, err := proxy.GetAppRedirect(req)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, tt.expectedRedirect, redirect) assert.Equal(t, tt.expectedRedirect, redirect)