mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-01-08 04:03:58 +02:00
Merge pull request #964 from grnhse/reverse-proxy-context
Track the ReverseProxy config setting in the request Scope
This commit is contained in:
commit
9c126f5740
@ -4,6 +4,8 @@
|
||||
|
||||
## Important Notes
|
||||
|
||||
- [#964](https://github.com/oauth2-proxy/oauth2-proxy/pull/964) Redirect URL generation will attempt secondary strategies
|
||||
in the priority chain if any fail the `IsValidRedirect` security check. Previously any failures fell back to `/`.
|
||||
- [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) Keycloak will now use `--profile-url` if set for the userinfo endpoint
|
||||
instead of `--validate-url`. `--validate-url` will still work for backwards compatibility.
|
||||
- [#957](https://github.com/oauth2-proxy/oauth2-proxy/pull/957) To use X-Forwarded-{Proto,Host,Uri} on redirect detection, `--reverse-proxy` must be `true`.
|
||||
@ -36,6 +38,11 @@
|
||||
|
||||
## Breaking Changes
|
||||
|
||||
- [#964](https://github.com/oauth2-proxy/oauth2-proxy/pull/964) `--reverse-proxy` must be true to trust `X-Forwarded-*` headers as canonical.
|
||||
These are used throughout the application in redirect URLs, cookie domains and host logging logic. These are the headers:
|
||||
- `X-Forwarded-Proto` instead of `req.URL.Scheme`
|
||||
- `X-Forwarded-Host` instead of `req.Host`
|
||||
- `X-Forwarded-Uri` instead of `req.URL.RequestURI()`
|
||||
- [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) In config files & envvar configs, `keycloak_group` is now the plural `keycloak_groups`.
|
||||
Flag configs are still `--keycloak-group` but it can be passed multiple times.
|
||||
- [#911](https://github.com/oauth2-proxy/oauth2-proxy/pull/911) Specifying a non-existent provider will cause OAuth2-Proxy to fail on startup instead of defaulting to "google".
|
||||
@ -60,6 +67,7 @@
|
||||
## Changes since v6.1.1
|
||||
|
||||
- [#995](https://github.com/oauth2-proxy/oauth2-proxy/pull/995) Add Security Policy (@JoelSpeed)
|
||||
- [#964](https://github.com/oauth2-proxy/oauth2-proxy/pull/964) Require `--reverse-proxy` true to trust `X-Forwareded-*` type headers (@NickMeves)
|
||||
- [#970](https://github.com/oauth2-proxy/oauth2-proxy/pull/970) Fix joined cookie name for those containing underline in the suffix (@peppered)
|
||||
- [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) Migrate Keycloak to EnrichSession & support multiple groups for authorization (@NickMeves)
|
||||
- [#957](https://github.com/oauth2-proxy/oauth2-proxy/pull/957) Use X-Forwarded-{Proto,Host,Uri} on redirect as last resort (@linuxgemini)
|
||||
|
617
oauthproxy.go
617
oauthproxy.go
@ -24,16 +24,14 @@ import (
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
|
||||
requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/upstream"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/providers"
|
||||
)
|
||||
|
||||
const (
|
||||
httpScheme = "http"
|
||||
httpsScheme = "https"
|
||||
|
||||
schemeHTTPS = "https"
|
||||
applicationJSON = "application/json"
|
||||
)
|
||||
|
||||
@ -98,7 +96,6 @@ type OAuthProxy struct {
|
||||
SetAuthorization bool
|
||||
PassAuthorization bool
|
||||
PreferEmailToUser bool
|
||||
ReverseProxy bool
|
||||
skipAuthPreflight bool
|
||||
skipJwtBearerTokens bool
|
||||
templates *template.Template
|
||||
@ -201,7 +198,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
||||
UserInfoPath: fmt.Sprintf("%s/userinfo", opts.ProxyPrefix),
|
||||
|
||||
ProxyPrefix: opts.ProxyPrefix,
|
||||
ReverseProxy: opts.ReverseProxy,
|
||||
provider: opts.GetProvider(),
|
||||
providerNameOverride: opts.ProviderName,
|
||||
sessionStore: sessionStore,
|
||||
@ -231,7 +227,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
||||
// the OAuth2 Proxy authentication logic kicks in.
|
||||
// For example forcing HTTPS or health checks.
|
||||
func buildPreAuthChain(opts *options.Options) (alice.Chain, error) {
|
||||
chain := alice.New(middleware.NewScope())
|
||||
chain := alice.New(middleware.NewScope(opts.ReverseProxy))
|
||||
|
||||
if opts.ForceHTTPS {
|
||||
_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress)
|
||||
@ -368,49 +364,6 @@ func buildRoutesAllowlist(opts *options.Options) ([]allowedRoute, error) {
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
// GetRedirectURI returns the redirectURL that the upstream OAuth Provider will
|
||||
// redirect clients to once authenticated
|
||||
func (p *OAuthProxy) GetRedirectURI(host string) string {
|
||||
// default to the request Host if not set
|
||||
if p.redirectURL.Host != "" {
|
||||
return p.redirectURL.String()
|
||||
}
|
||||
u := *p.redirectURL
|
||||
if u.Scheme == "" {
|
||||
if p.CookieSecure {
|
||||
u.Scheme = httpsScheme
|
||||
} else {
|
||||
u.Scheme = httpScheme
|
||||
}
|
||||
}
|
||||
u.Host = host
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) {
|
||||
if code == "" {
|
||||
return nil, providers.ErrMissingCode
|
||||
}
|
||||
redirectURI := p.GetRedirectURI(host)
|
||||
s, err := p.provider.Redeem(ctx, redirectURI, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.SessionState) error {
|
||||
var err error
|
||||
if s.Email == "" {
|
||||
s.Email, err = p.provider.GetEmailAddress(ctx, s)
|
||||
if err != nil && !errors.Is(err, providers.ErrNotImplemented) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return p.provider.EnrichSession(ctx, s)
|
||||
}
|
||||
|
||||
// MakeCSRFCookie creates a cookie for CSRF
|
||||
func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
return p.makeCookie(req, p.CSRFCookieName, value, expiration, now)
|
||||
@ -420,7 +373,7 @@ func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, ex
|
||||
cookieDomain := cookies.GetCookieDomain(req, p.CookieDomains)
|
||||
|
||||
if cookieDomain != "" {
|
||||
domain := util.GetRequestHost(req)
|
||||
domain := requestutil.GetRequestHost(req)
|
||||
if h, _, err := net.SplitHostPort(domain); err == nil {
|
||||
domain = h
|
||||
}
|
||||
@ -468,6 +421,81 @@ func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *s
|
||||
return p.sessionStore.Save(rw, req, s)
|
||||
}
|
||||
|
||||
// IsValidRedirect checks whether the redirect URL is whitelisted
|
||||
func (p *OAuthProxy) IsValidRedirect(redirect string) bool {
|
||||
switch {
|
||||
case redirect == "":
|
||||
// The user didn't specify a redirect, should fallback to `/`
|
||||
return false
|
||||
case strings.HasPrefix(redirect, "/") && !strings.HasPrefix(redirect, "//") && !invalidRedirectRegex.MatchString(redirect):
|
||||
return true
|
||||
case strings.HasPrefix(redirect, "http://") || strings.HasPrefix(redirect, "https://"):
|
||||
redirectURL, err := url.Parse(redirect)
|
||||
if err != nil {
|
||||
logger.Printf("Rejecting invalid redirect %q: scheme unsupported or missing", redirect)
|
||||
return false
|
||||
}
|
||||
redirectHostname := redirectURL.Hostname()
|
||||
|
||||
for _, domain := range p.whitelistDomains {
|
||||
domainHostname, domainPort := splitHostPort(strings.TrimLeft(domain, "."))
|
||||
if domainHostname == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if (redirectHostname == domainHostname) || (strings.HasPrefix(domain, ".") && strings.HasSuffix(redirectHostname, domainHostname)) {
|
||||
// the domain names match, now validate the ports
|
||||
// if the whitelisted domain's port is '*', allow all ports
|
||||
// if the whitelisted domain contains a specific port, only allow that port
|
||||
// if the whitelisted domain doesn't contain a port at all, only allow empty redirect ports ie http and https
|
||||
redirectPort := redirectURL.Port()
|
||||
if (domainPort == "*") ||
|
||||
(domainPort == redirectPort) ||
|
||||
(domainPort == "" && redirectPort == "") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.Printf("Rejecting invalid redirect %q: domain / port not in whitelist", redirect)
|
||||
return false
|
||||
default:
|
||||
logger.Printf("Rejecting invalid redirect %q: not an absolute or relative URL", redirect)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
p.preAuthChain.Then(http.HandlerFunc(p.serveHTTP)).ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.URL.Path != p.AuthOnlyPath && strings.HasPrefix(req.URL.Path, p.ProxyPrefix) {
|
||||
prepareNoCache(rw)
|
||||
}
|
||||
|
||||
switch path := req.URL.Path; {
|
||||
case path == p.RobotsPath:
|
||||
p.RobotsTxt(rw)
|
||||
case p.IsAllowedRequest(req):
|
||||
p.SkipAuthProxy(rw, req)
|
||||
case path == p.SignInPath:
|
||||
p.SignIn(rw, req)
|
||||
case path == p.SignOutPath:
|
||||
p.SignOut(rw, req)
|
||||
case path == p.OAuthStartPath:
|
||||
p.OAuthStart(rw, req)
|
||||
case path == p.OAuthCallbackPath:
|
||||
p.OAuthCallback(rw, req)
|
||||
case path == p.AuthOnlyPath:
|
||||
p.AuthOnly(rw, req)
|
||||
case path == p.UserInfoPath:
|
||||
p.UserInfo(rw, req)
|
||||
default:
|
||||
p.Proxy(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
// RobotsTxt disallows scraping pages from the OAuthProxy
|
||||
func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) {
|
||||
_, err := fmt.Fprintf(rw, "User-agent: *\nDisallow: /")
|
||||
@ -498,6 +526,42 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m
|
||||
}
|
||||
}
|
||||
|
||||
// IsAllowedRequest is used to check if auth should be skipped for this request
|
||||
func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool {
|
||||
isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS"
|
||||
return isPreflightRequestAllowed || p.isAllowedRoute(req) || p.isTrustedIP(req)
|
||||
}
|
||||
|
||||
// IsAllowedRoute is used to check if the request method & path is allowed without auth
|
||||
func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool {
|
||||
for _, route := range p.allowedRoutes {
|
||||
if (route.method == "" || req.Method == route.method) && route.pathRegex.MatchString(req.URL.Path) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isTrustedIP is used to check if a request comes from a trusted client IP address.
|
||||
func (p *OAuthProxy) isTrustedIP(req *http.Request) bool {
|
||||
if p.trustedIPs == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req)
|
||||
if err != nil {
|
||||
logger.Errorf("Error obtaining real IP for trusted IP list: %v", err)
|
||||
// Possibly spoofed X-Real-IP header
|
||||
return false
|
||||
}
|
||||
|
||||
if remoteAddr == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return p.trustedIPs.Has(remoteAddr)
|
||||
}
|
||||
|
||||
// SignInPage writes the sing in template to the response
|
||||
func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) {
|
||||
prepareNoCache(rw)
|
||||
@ -509,7 +573,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code
|
||||
}
|
||||
rw.WriteHeader(code)
|
||||
|
||||
redirectURL, err := p.GetRedirect(req)
|
||||
redirectURL, err := p.getAppRedirect(req)
|
||||
if err != nil {
|
||||
logger.Errorf("Error obtaining redirect: %v", err)
|
||||
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
||||
@ -568,220 +632,9 @@ func (p *OAuthProxy) ManualSignIn(req *http.Request) (string, bool) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// GetRedirect reads the query parameter to get the URL to redirect clients to
|
||||
// once authenticated with the OAuthProxy
|
||||
func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) {
|
||||
err = req.ParseForm()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
redirect = req.Header.Get("X-Auth-Request-Redirect")
|
||||
if req.Form.Get("rd") != "" {
|
||||
redirect = req.Form.Get("rd")
|
||||
}
|
||||
// Quirk: On reverse proxies that doesn't have support for
|
||||
// "X-Auth-Request-Redirect" header or dynamic header/query string
|
||||
// manipulation (like Traefik v1 and v2), we can try if the header
|
||||
// X-Forwarded-Host exists or not.
|
||||
if redirect == "" && isForwardedRequest(req, p.ReverseProxy) {
|
||||
redirect = p.getRedirectFromForwardHeaders(req)
|
||||
}
|
||||
if !p.IsValidRedirect(redirect) {
|
||||
// Use RequestURI to preserve ?query
|
||||
redirect = req.URL.RequestURI()
|
||||
|
||||
if strings.HasPrefix(redirect, fmt.Sprintf("%s/", p.ProxyPrefix)) {
|
||||
redirect = "/"
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// getRedirectFromForwardHeaders returns the redirect URL based on X-Forwarded-{Proto,Host,Uri} headers
|
||||
func (p *OAuthProxy) getRedirectFromForwardHeaders(req *http.Request) string {
|
||||
uri := util.GetRequestURI(req)
|
||||
|
||||
if strings.HasPrefix(uri, fmt.Sprintf("%s/", p.ProxyPrefix)) {
|
||||
uri = "/"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s://%s%s", util.GetRequestProto(req), util.GetRequestHost(req), uri)
|
||||
}
|
||||
|
||||
// splitHostPort separates host and port. If the port is not valid, it returns
|
||||
// the entire input as host, and it doesn't check the validity of the host.
|
||||
// Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric.
|
||||
// *** taken from net/url, modified validOptionalPort() to accept ":*"
|
||||
func splitHostPort(hostport string) (host, port string) {
|
||||
host = hostport
|
||||
|
||||
colon := strings.LastIndexByte(host, ':')
|
||||
if colon != -1 && validOptionalPort(host[colon:]) {
|
||||
host, port = host[:colon], host[colon+1:]
|
||||
}
|
||||
|
||||
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
||||
host = host[1 : len(host)-1]
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// validOptionalPort reports whether port is either an empty string
|
||||
// or matches /^:\d*$/
|
||||
// *** taken from net/url, modified to accept ":*"
|
||||
func validOptionalPort(port string) bool {
|
||||
if port == "" || port == ":*" {
|
||||
return true
|
||||
}
|
||||
if port[0] != ':' {
|
||||
return false
|
||||
}
|
||||
for _, b := range port[1:] {
|
||||
if b < '0' || b > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// IsValidRedirect checks whether the redirect URL is whitelisted
|
||||
func (p *OAuthProxy) IsValidRedirect(redirect string) bool {
|
||||
switch {
|
||||
case redirect == "":
|
||||
// The user didn't specify a redirect, should fallback to `/`
|
||||
return false
|
||||
case strings.HasPrefix(redirect, "/") && !strings.HasPrefix(redirect, "//") && !invalidRedirectRegex.MatchString(redirect):
|
||||
return true
|
||||
case strings.HasPrefix(redirect, "http://") || strings.HasPrefix(redirect, "https://"):
|
||||
redirectURL, err := url.Parse(redirect)
|
||||
if err != nil {
|
||||
logger.Printf("Rejecting invalid redirect %q: scheme unsupported or missing", redirect)
|
||||
return false
|
||||
}
|
||||
redirectHostname := redirectURL.Hostname()
|
||||
|
||||
for _, domain := range p.whitelistDomains {
|
||||
domainHostname, domainPort := splitHostPort(strings.TrimLeft(domain, "."))
|
||||
if domainHostname == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if (redirectHostname == domainHostname) || (strings.HasPrefix(domain, ".") && strings.HasSuffix(redirectHostname, domainHostname)) {
|
||||
// the domain names match, now validate the ports
|
||||
// if the whitelisted domain's port is '*', allow all ports
|
||||
// if the whitelisted domain contains a specific port, only allow that port
|
||||
// if the whitelisted domain doesn't contain a port at all, only allow empty redirect ports ie http and https
|
||||
redirectPort := redirectURL.Port()
|
||||
if (domainPort == "*") ||
|
||||
(domainPort == redirectPort) ||
|
||||
(domainPort == "" && redirectPort == "") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.Printf("Rejecting invalid redirect %q: domain / port not in whitelist", redirect)
|
||||
return false
|
||||
default:
|
||||
logger.Printf("Rejecting invalid redirect %q: not an absolute or relative URL", redirect)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// IsAllowedRequest is used to check if auth should be skipped for this request
|
||||
func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool {
|
||||
isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS"
|
||||
return isPreflightRequestAllowed || p.isAllowedRoute(req) || p.IsTrustedIP(req)
|
||||
}
|
||||
|
||||
// IsAllowedRoute is used to check if the request method & path is allowed without auth
|
||||
func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool {
|
||||
for _, route := range p.allowedRoutes {
|
||||
if (route.method == "" || req.Method == route.method) && route.pathRegex.MatchString(req.URL.Path) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isForwardedRequest is used to check if X-Forwarded-Host header exists or not
|
||||
func isForwardedRequest(req *http.Request, reverseProxy bool) bool {
|
||||
isForwarded := req.Host != util.GetRequestHost(req)
|
||||
return isForwarded && reverseProxy
|
||||
}
|
||||
|
||||
// See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en
|
||||
var noCacheHeaders = map[string]string{
|
||||
"Expires": time.Unix(0, 0).Format(time.RFC1123),
|
||||
"Cache-Control": "no-cache, no-store, must-revalidate, max-age=0",
|
||||
"X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/
|
||||
}
|
||||
|
||||
// prepareNoCache prepares headers for preventing browser caching.
|
||||
func prepareNoCache(w http.ResponseWriter) {
|
||||
// Set NoCache headers
|
||||
for k, v := range noCacheHeaders {
|
||||
w.Header().Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
// IsTrustedIP is used to check if a request comes from a trusted client IP address.
|
||||
func (p *OAuthProxy) IsTrustedIP(req *http.Request) bool {
|
||||
if p.trustedIPs == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req)
|
||||
if err != nil {
|
||||
logger.Errorf("Error obtaining real IP for trusted IP list: %v", err)
|
||||
// Possibly spoofed X-Real-IP header
|
||||
return false
|
||||
}
|
||||
|
||||
if remoteAddr == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return p.trustedIPs.Has(remoteAddr)
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
p.preAuthChain.Then(http.HandlerFunc(p.serveHTTP)).ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.URL.Path != p.AuthOnlyPath && strings.HasPrefix(req.URL.Path, p.ProxyPrefix) {
|
||||
prepareNoCache(rw)
|
||||
}
|
||||
|
||||
switch path := req.URL.Path; {
|
||||
case path == p.RobotsPath:
|
||||
p.RobotsTxt(rw)
|
||||
case p.IsAllowedRequest(req):
|
||||
p.SkipAuthProxy(rw, req)
|
||||
case path == p.SignInPath:
|
||||
p.SignIn(rw, req)
|
||||
case path == p.SignOutPath:
|
||||
p.SignOut(rw, req)
|
||||
case path == p.OAuthStartPath:
|
||||
p.OAuthStart(rw, req)
|
||||
case path == p.OAuthCallbackPath:
|
||||
p.OAuthCallback(rw, req)
|
||||
case path == p.AuthOnlyPath:
|
||||
p.AuthOnly(rw, req)
|
||||
case path == p.UserInfoPath:
|
||||
p.UserInfo(rw, req)
|
||||
default:
|
||||
p.Proxy(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
// SignIn serves a page prompting users to sign in
|
||||
func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
|
||||
redirect, err := p.GetRedirect(req)
|
||||
redirect, err := p.getAppRedirect(req)
|
||||
if err != nil {
|
||||
logger.Errorf("Error obtaining redirect: %v", err)
|
||||
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
||||
@ -839,7 +692,7 @@ func (p *OAuthProxy) UserInfo(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
// SignOut sends a response to clear the authentication cookie
|
||||
func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
|
||||
redirect, err := p.GetRedirect(req)
|
||||
redirect, err := p.getAppRedirect(req)
|
||||
if err != nil {
|
||||
logger.Errorf("Error obtaining redirect: %v", err)
|
||||
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
||||
@ -864,13 +717,13 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
p.SetCSRFCookie(rw, req, nonce)
|
||||
redirect, err := p.GetRedirect(req)
|
||||
redirect, err := p.getAppRedirect(req)
|
||||
if err != nil {
|
||||
logger.Errorf("Error obtaining redirect: %v", err)
|
||||
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
||||
return
|
||||
}
|
||||
redirectURI := p.GetRedirectURI(util.GetRequestHost(req))
|
||||
redirectURI := p.getOAuthRedirectURI(req)
|
||||
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound)
|
||||
}
|
||||
|
||||
@ -893,7 +746,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
session, err := p.redeemCode(req.Context(), util.GetRequestHost(req), req.Form.Get("code"))
|
||||
session, err := p.redeemCode(req)
|
||||
if err != nil {
|
||||
logger.Errorf("Error redeeming code during OAuth2 callback: %v", err)
|
||||
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error")
|
||||
@ -952,6 +805,32 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) redeemCode(req *http.Request) (*sessionsapi.SessionState, error) {
|
||||
code := req.Form.Get("code")
|
||||
if code == "" {
|
||||
return nil, providers.ErrMissingCode
|
||||
}
|
||||
|
||||
redirectURI := p.getOAuthRedirectURI(req)
|
||||
s, err := p.provider.Redeem(req.Context(), redirectURI, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.SessionState) error {
|
||||
var err error
|
||||
if s.Email == "" {
|
||||
s.Email, err = p.provider.GetEmailAddress(ctx, s)
|
||||
if err != nil && !errors.Is(err, providers.ErrNotImplemented) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return p.provider.EnrichSession(ctx, s)
|
||||
}
|
||||
|
||||
// AuthOnly checks whether the user is currently logged in (both authentication
|
||||
// and optional authorization).
|
||||
func (p *OAuthProxy) AuthOnly(rw http.ResponseWriter, req *http.Request) {
|
||||
@ -969,7 +848,7 @@ func (p *OAuthProxy) AuthOnly(rw http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
|
||||
// we are authenticated
|
||||
p.addHeadersForProxying(rw, req, session)
|
||||
p.addHeadersForProxying(rw, session)
|
||||
p.headersChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusAccepted)
|
||||
})).ServeHTTP(rw, req)
|
||||
@ -987,13 +866,13 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
|
||||
switch err {
|
||||
case nil:
|
||||
// we are authenticated
|
||||
p.addHeadersForProxying(rw, req, session)
|
||||
p.addHeadersForProxying(rw, session)
|
||||
p.headersChain.Then(p.serveMux).ServeHTTP(rw, req)
|
||||
case ErrNeedsLogin:
|
||||
// we need to send the user to a login screen
|
||||
if isAjax(req) {
|
||||
// no point redirecting an AJAX request
|
||||
p.ErrorJSON(rw, http.StatusUnauthorized)
|
||||
p.errorJSON(rw, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
@ -1012,7 +891,195 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
|
||||
p.ErrorPage(rw, http.StatusInternalServerError,
|
||||
"Internal Error", "Internal Error")
|
||||
}
|
||||
}
|
||||
|
||||
// See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en
|
||||
var noCacheHeaders = map[string]string{
|
||||
"Expires": time.Unix(0, 0).Format(time.RFC1123),
|
||||
"Cache-Control": "no-cache, no-store, must-revalidate, max-age=0",
|
||||
"X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/
|
||||
}
|
||||
|
||||
// prepareNoCache prepares headers for preventing browser caching.
|
||||
func prepareNoCache(w http.ResponseWriter) {
|
||||
// Set NoCache headers
|
||||
for k, v := range noCacheHeaders {
|
||||
w.Header().Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
// getOAuthRedirectURI returns the redirectURL that the upstream OAuth Provider will
|
||||
// redirect clients to once authenticated.
|
||||
// This is usually the OAuthProxy callback URL.
|
||||
func (p *OAuthProxy) getOAuthRedirectURI(req *http.Request) string {
|
||||
// if `p.redirectURL` already has a host, return it
|
||||
if p.redirectURL.Host != "" {
|
||||
return p.redirectURL.String()
|
||||
}
|
||||
|
||||
// Otherwise figure out the scheme + host from the request
|
||||
rd := *p.redirectURL
|
||||
rd.Host = requestutil.GetRequestHost(req)
|
||||
rd.Scheme = requestutil.GetRequestProto(req)
|
||||
|
||||
// If CookieSecure is true, return `https` no matter what
|
||||
// Not all reverse proxies set X-Forwarded-Proto
|
||||
if p.CookieSecure {
|
||||
rd.Scheme = schemeHTTPS
|
||||
}
|
||||
return rd.String()
|
||||
}
|
||||
|
||||
// getAppRedirect determines the full URL or URI path to redirect clients to
|
||||
// once authenticated with the OAuthProxy
|
||||
// Strategy priority (first legal result is used):
|
||||
// - `rd` querysting parameter
|
||||
// - `X-Auth-Request-Redirect` header
|
||||
// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled)
|
||||
// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*)
|
||||
// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled)
|
||||
// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*)
|
||||
// - `/`
|
||||
func (p *OAuthProxy) getAppRedirect(req *http.Request) (string, error) {
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// These redirect getter functions are strategies ordered by priority
|
||||
// for figuring out the redirect URL.
|
||||
type redirectGetter func(req *http.Request) string
|
||||
for _, rdGetter := range []redirectGetter{
|
||||
p.getRdQuerystringRedirect,
|
||||
p.getXAuthRequestRedirect,
|
||||
p.getXForwardedHeadersRedirect,
|
||||
p.getURIRedirect,
|
||||
} {
|
||||
redirect := rdGetter(req)
|
||||
// Call `p.IsValidRedirect` again here a final time to be safe
|
||||
if redirect != "" && p.IsValidRedirect(redirect) {
|
||||
return redirect, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "/", nil
|
||||
}
|
||||
|
||||
func isForwardedRequest(req *http.Request) bool {
|
||||
return requestutil.IsProxied(req) &&
|
||||
req.Host != requestutil.GetRequestHost(req)
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) hasProxyPrefix(path string) bool {
|
||||
return strings.HasPrefix(path, fmt.Sprintf("%s/", p.ProxyPrefix))
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) validateRedirect(redirect string, errorFormat string) string {
|
||||
if p.IsValidRedirect(redirect) {
|
||||
return redirect
|
||||
}
|
||||
if redirect != "" {
|
||||
logger.Errorf(errorFormat, redirect)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getRdQuerystringRedirect handles this getAppRedirect strategy:
|
||||
// - `rd` querysting parameter
|
||||
func (p *OAuthProxy) getRdQuerystringRedirect(req *http.Request) string {
|
||||
return p.validateRedirect(
|
||||
req.Form.Get("rd"),
|
||||
"Invalid redirect provided in rd querystring parameter: %s",
|
||||
)
|
||||
}
|
||||
|
||||
// getXAuthRequestRedirect handles this getAppRedirect strategy:
|
||||
// - `X-Auth-Request-Redirect` Header
|
||||
func (p *OAuthProxy) getXAuthRequestRedirect(req *http.Request) string {
|
||||
return p.validateRedirect(
|
||||
req.Header.Get("X-Auth-Request-Redirect"),
|
||||
"Invalid redirect provided in X-Auth-Request-Redirect header: %s",
|
||||
)
|
||||
}
|
||||
|
||||
// getXForwardedHeadersRedirect handles these getAppRedirect strategies:
|
||||
// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled)
|
||||
// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*)
|
||||
func (p *OAuthProxy) getXForwardedHeadersRedirect(req *http.Request) string {
|
||||
if !isForwardedRequest(req) {
|
||||
return ""
|
||||
}
|
||||
|
||||
uri := requestutil.GetRequestURI(req)
|
||||
if p.hasProxyPrefix(uri) {
|
||||
uri = "/"
|
||||
}
|
||||
|
||||
redirect := fmt.Sprintf(
|
||||
"%s://%s%s",
|
||||
requestutil.GetRequestProto(req),
|
||||
requestutil.GetRequestHost(req),
|
||||
uri,
|
||||
)
|
||||
|
||||
return p.validateRedirect(redirect,
|
||||
"Invalid redirect generated from X-Forwarded-* headers: %s")
|
||||
}
|
||||
|
||||
// getURIRedirect handles these getAppRedirect strategies:
|
||||
// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled)
|
||||
// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*)
|
||||
// - `/`
|
||||
func (p *OAuthProxy) getURIRedirect(req *http.Request) string {
|
||||
redirect := p.validateRedirect(
|
||||
requestutil.GetRequestURI(req),
|
||||
"Invalid redirect generated from X-Forwarded-Uri header: %s",
|
||||
)
|
||||
if redirect == "" {
|
||||
redirect = req.URL.RequestURI()
|
||||
}
|
||||
|
||||
if p.hasProxyPrefix(redirect) {
|
||||
return "/"
|
||||
}
|
||||
return redirect
|
||||
}
|
||||
|
||||
// splitHostPort separates host and port. If the port is not valid, it returns
|
||||
// the entire input as host, and it doesn't check the validity of the host.
|
||||
// Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric.
|
||||
// *** taken from net/url, modified validOptionalPort() to accept ":*"
|
||||
func splitHostPort(hostport string) (host, port string) {
|
||||
host = hostport
|
||||
|
||||
colon := strings.LastIndexByte(host, ':')
|
||||
if colon != -1 && validOptionalPort(host[colon:]) {
|
||||
host, port = host[:colon], host[colon+1:]
|
||||
}
|
||||
|
||||
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
||||
host = host[1 : len(host)-1]
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// validOptionalPort reports whether port is either an empty string
|
||||
// or matches /^:\d*$/
|
||||
// *** taken from net/url, modified to accept ":*"
|
||||
func validOptionalPort(port string) bool {
|
||||
if port == "" || port == ":*" {
|
||||
return true
|
||||
}
|
||||
if port[0] != ':' {
|
||||
return false
|
||||
}
|
||||
for _, b := range port[1:] {
|
||||
if b < '0' || b > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// getAuthenticatedSession checks whether a user is authenticated and returns a session object and nil error if so
|
||||
@ -1024,7 +1091,7 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R
|
||||
var session *sessionsapi.SessionState
|
||||
|
||||
getSession := p.sessionChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
session = middleware.GetRequestScope(req).Session
|
||||
session = middlewareapi.GetRequestScope(req).Session
|
||||
}))
|
||||
getSession.ServeHTTP(rw, req)
|
||||
|
||||
@ -1099,7 +1166,7 @@ func extractAllowedGroups(req *http.Request) map[string]struct{} {
|
||||
}
|
||||
|
||||
// addHeadersForProxying adds the appropriate headers the request / response for proxying
|
||||
func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) {
|
||||
func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, session *sessionsapi.SessionState) {
|
||||
if session.Email == "" {
|
||||
rw.Header().Set("GAP-Auth", session.User)
|
||||
} else {
|
||||
@ -1127,8 +1194,8 @@ func isAjax(req *http.Request) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// ErrorJSON returns the error code with an application/json mime type
|
||||
func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) {
|
||||
// errorJSON returns the error code with an application/json mime type
|
||||
func (p *OAuthProxy) errorJSON(rw http.ResponseWriter, code int) {
|
||||
rw.Header().Set("Content-Type", applicationJSON)
|
||||
rw.WriteHeader(code)
|
||||
}
|
||||
|
@ -19,6 +19,7 @@ import (
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/mbland/hmacauth"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
@ -414,8 +415,9 @@ func Test_redeemCode(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = proxy.redeemCode(context.Background(), "www.example.com", "")
|
||||
assert.Error(t, err)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
_, err = proxy.redeemCode(req)
|
||||
assert.Equal(t, providers.ErrMissingCode, err)
|
||||
}
|
||||
|
||||
func Test_enrichSession(t *testing.T) {
|
||||
@ -1748,10 +1750,9 @@ func TestRequestSignature(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRedirect(t *testing.T) {
|
||||
func Test_getAppRedirect(t *testing.T) {
|
||||
opts := baseTestOptions()
|
||||
opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com")
|
||||
opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com:8443")
|
||||
opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com", ".example.com:8443")
|
||||
err := validation.Validate(opts)
|
||||
assert.NoError(t, err)
|
||||
require.NotEmpty(t, opts.ProxyPrefix)
|
||||
@ -1854,9 +1855,6 @@ func TestGetRedirect(t *testing.T) {
|
||||
url: "https://oauth.example.com/foo/bar",
|
||||
headers: map[string]string{
|
||||
"X-Auth-Request-Redirect": "https://a-service.example.com/foo/bar",
|
||||
"X-Forwarded-Proto": "",
|
||||
"X-Forwarded-Host": "",
|
||||
"X-Forwarded-Uri": "",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedRedirect: "https://a-service.example.com/foo/bar",
|
||||
@ -1884,10 +1882,9 @@ func TestGetRedirect(t *testing.T) {
|
||||
name: "proxied request with rd query string and some headers set redirects to proxied URL on rd query string",
|
||||
url: "https://oauth.example.com/foo/bar?rd=https%3A%2F%2Fa%2Dservice%2Eexample%2Ecom%2Ffoo%2Fbaz",
|
||||
headers: map[string]string{
|
||||
"X-Auth-Request-Redirect": "",
|
||||
"X-Forwarded-Proto": "https",
|
||||
"X-Forwarded-Host": "another-service.example.com",
|
||||
"X-Forwarded-Uri": "/seasons/greetings",
|
||||
"X-Forwarded-Proto": "https",
|
||||
"X-Forwarded-Host": "another-service.example.com",
|
||||
"X-Forwarded-Uri": "/seasons/greetings",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedRedirect: "https://a-service.example.com/foo/baz",
|
||||
@ -1901,8 +1898,10 @@ func TestGetRedirect(t *testing.T) {
|
||||
req.Header.Add(header, value)
|
||||
}
|
||||
}
|
||||
proxy.ReverseProxy = tt.reverseProxy
|
||||
redirect, err := proxy.GetRedirect(req)
|
||||
req = middleware.AddRequestScope(req, &middleware.RequestScope{
|
||||
ReverseProxy: tt.reverseProxy,
|
||||
})
|
||||
redirect, err := proxy.getAppRedirect(req)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedRedirect, redirect)
|
||||
|
19
pkg/apis/middleware/middleware_suite_test.go
Normal file
19
pkg/apis/middleware/middleware_suite_test.go
Normal file
@ -0,0 +1,19 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// TestMiddlewareSuite and related tests are in a *_test package
|
||||
// to prevent circular imports with the `logger` package which uses
|
||||
// this functionality
|
||||
func TestMiddlewareSuite(t *testing.T) {
|
||||
logger.SetOutput(GinkgoWriter)
|
||||
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Middleware API")
|
||||
}
|
@ -1,13 +1,26 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
)
|
||||
|
||||
type scopeKey string
|
||||
|
||||
// RequestScopeKey uses a typed string to reduce likelihood of clashing
|
||||
// with other context keys
|
||||
const RequestScopeKey scopeKey = "request-scope"
|
||||
|
||||
// RequestScope contains information regarding the request that is being made.
|
||||
// The RequestScope is used to pass information between different middlewares
|
||||
// within the chain.
|
||||
type RequestScope struct {
|
||||
// ReverseProxy tracks whether OAuth2-Proxy is operating in reverse proxy
|
||||
// mode and if request `X-Forwarded-*` headers should be trusted
|
||||
ReverseProxy bool
|
||||
|
||||
// Session details the authenticated users information (if it exists).
|
||||
Session *sessions.SessionState
|
||||
|
||||
@ -22,3 +35,19 @@ type RequestScope struct {
|
||||
// it was loaded or not.
|
||||
SessionRevalidated bool
|
||||
}
|
||||
|
||||
// GetRequestScope returns the current request scope from the given request
|
||||
func GetRequestScope(req *http.Request) *RequestScope {
|
||||
scope := req.Context().Value(RequestScopeKey)
|
||||
if scope == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return scope.(*RequestScope)
|
||||
}
|
||||
|
||||
// AddRequestScope adds a RequestScope to a request
|
||||
func AddRequestScope(req *http.Request, scope *RequestScope) *http.Request {
|
||||
ctx := context.WithValue(req.Context(), RequestScopeKey, scope)
|
||||
return req.WithContext(ctx)
|
||||
}
|
||||
|
56
pkg/apis/middleware/scope_test.go
Normal file
56
pkg/apis/middleware/scope_test.go
Normal file
@ -0,0 +1,56 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Scope Suite", func() {
|
||||
Context("GetRequestScope", func() {
|
||||
var request *http.Request
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
request, err = http.NewRequest("", "http://127.0.0.1/", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
Context("with a scope", func() {
|
||||
var scope *middleware.RequestScope
|
||||
|
||||
BeforeEach(func() {
|
||||
scope = &middleware.RequestScope{}
|
||||
request = middleware.AddRequestScope(request, scope)
|
||||
})
|
||||
|
||||
It("returns the scope", func() {
|
||||
s := middleware.GetRequestScope(request)
|
||||
Expect(s).ToNot(BeNil())
|
||||
Expect(s).To(Equal(scope))
|
||||
})
|
||||
|
||||
Context("if the scope is then modified", func() {
|
||||
BeforeEach(func() {
|
||||
Expect(scope.SaveSession).To(BeFalse())
|
||||
scope.SaveSession = true
|
||||
})
|
||||
|
||||
It("returns the updated session", func() {
|
||||
s := middleware.GetRequestScope(request)
|
||||
Expect(s).ToNot(BeNil())
|
||||
Expect(s).To(Equal(scope))
|
||||
Expect(s.SaveSession).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("without a scope", func() {
|
||||
It("returns nil", func() {
|
||||
Expect(middleware.GetRequestScope(request)).To(BeNil())
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
@ -9,14 +9,14 @@ import (
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
|
||||
requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
|
||||
)
|
||||
|
||||
// MakeCookie constructs a cookie from the given parameters,
|
||||
// discovering the domain from the request if not specified.
|
||||
func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time, sameSite http.SameSite) *http.Cookie {
|
||||
if domain != "" {
|
||||
host := util.GetRequestHost(req)
|
||||
host := requestutil.GetRequestHost(req)
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = h
|
||||
}
|
||||
@ -48,7 +48,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO
|
||||
// If nothing matches, create the cookie with the shortest domain
|
||||
defaultDomain := ""
|
||||
if len(cookieOpts.Domains) > 0 {
|
||||
logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", util.GetRequestHost(req), strings.Join(cookieOpts.Domains, ","))
|
||||
logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", requestutil.GetRequestHost(req), strings.Join(cookieOpts.Domains, ","))
|
||||
defaultDomain = cookieOpts.Domains[len(cookieOpts.Domains)-1]
|
||||
}
|
||||
return MakeCookie(req, name, value, cookieOpts.Path, defaultDomain, cookieOpts.HTTPOnly, cookieOpts.Secure, expiration, now, ParseSameSite(cookieOpts.SameSite))
|
||||
@ -57,7 +57,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO
|
||||
// GetCookieDomain returns the correct cookie domain given a list of domains
|
||||
// by checking the X-Fowarded-Host and host header of an an http request
|
||||
func GetCookieDomain(req *http.Request, cookieDomains []string) string {
|
||||
host := util.GetRequestHost(req)
|
||||
host := requestutil.GetRequestHost(req)
|
||||
for _, domain := range cookieDomains {
|
||||
if strings.HasSuffix(host, domain) {
|
||||
return domain
|
||||
|
@ -12,7 +12,7 @@ import (
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
|
||||
requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
|
||||
)
|
||||
|
||||
// AuthStatus defines the different types of auth logging that occur
|
||||
@ -197,7 +197,7 @@ func (l *Logger) PrintAuthf(username string, req *http.Request, status AuthStatu
|
||||
|
||||
err := l.authTemplate.Execute(l.writer, authLogMessageData{
|
||||
Client: client,
|
||||
Host: util.GetRequestHost(req),
|
||||
Host: requestutil.GetRequestHost(req),
|
||||
Protocol: req.Proto,
|
||||
RequestMethod: req.Method,
|
||||
Timestamp: FormatTimestamp(now),
|
||||
@ -251,7 +251,7 @@ func (l *Logger) PrintReq(username, upstream string, req *http.Request, url url.
|
||||
|
||||
err := l.reqTemplate.Execute(l.writer, reqLogMessageData{
|
||||
Client: client,
|
||||
Host: util.GetRequestHost(req),
|
||||
Host: requestutil.GetRequestHost(req),
|
||||
Protocol: req.Proto,
|
||||
RequestDuration: fmt.Sprintf("%0.3f", duration),
|
||||
RequestMethod: req.Method,
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
@ -23,7 +24,7 @@ func NewBasicAuthSessionLoader(validator basic.Validator) alice.Constructor {
|
||||
// If a session was loaded by a previous handler, it will not be replaced.
|
||||
func loadBasicAuthSession(validator basic.Validator, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
scope := GetRequestScope(req)
|
||||
scope := middlewareapi.GetRequestScope(req)
|
||||
// If scope is nil, this will panic.
|
||||
// A scope should always be injected before this handler is called.
|
||||
if scope.Session != nil {
|
||||
|
@ -1,7 +1,6 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -40,8 +39,7 @@ var _ = Describe("Basic Auth Session Suite", func() {
|
||||
// Set up the request with the authorization header and a request scope
|
||||
req := httptest.NewRequest("", "/", nil)
|
||||
req.Header.Set("Authorization", in.authorizationHeader)
|
||||
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
|
||||
req = req.WithContext(contextWithScope)
|
||||
req = middlewareapi.AddRequestScope(req, scope)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
@ -57,7 +55,7 @@ var _ = Describe("Basic Auth Session Suite", func() {
|
||||
// from the scope
|
||||
var gotSession *sessionsapi.SessionState
|
||||
handler := NewBasicAuthSessionLoader(validator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session
|
||||
gotSession = middlewareapi.GetRequestScope(r).Session
|
||||
}))
|
||||
handler.ServeHTTP(rw, req)
|
||||
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/header"
|
||||
)
|
||||
@ -61,7 +62,7 @@ func newRequestHeaderInjector(headers []options.Header) (alice.Constructor, erro
|
||||
|
||||
func injectRequestHeaders(injector header.Injector, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
scope := GetRequestScope(req)
|
||||
scope := middlewareapi.GetRequestScope(req)
|
||||
|
||||
// If scope is nil, this will panic.
|
||||
// A scope should always be injected before this handler is called.
|
||||
@ -92,7 +93,7 @@ func newResponseHeaderInjector(headers []options.Header) (alice.Constructor, err
|
||||
|
||||
func injectResponseHeaders(injector header.Injector, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
scope := GetRequestScope(req)
|
||||
scope := middlewareapi.GetRequestScope(req)
|
||||
|
||||
// If scope is nil, this will panic.
|
||||
// A scope should always be injected before this handler is called.
|
||||
|
@ -1,7 +1,6 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -31,8 +30,7 @@ var _ = Describe("Headers Suite", func() {
|
||||
|
||||
// Set up the request with a request scope
|
||||
req := httptest.NewRequest("", "/", nil)
|
||||
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
|
||||
req = req.WithContext(contextWithScope)
|
||||
req = middlewareapi.AddRequestScope(req, scope)
|
||||
req.Header = in.initialHeaders.Clone()
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
@ -218,8 +216,7 @@ var _ = Describe("Headers Suite", func() {
|
||||
|
||||
// Set up the request with a request scope
|
||||
req := httptest.NewRequest("", "/", nil)
|
||||
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
|
||||
req = req.WithContext(contextWithScope)
|
||||
req = middlewareapi.AddRequestScope(req, scope)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
for key, values := range in.initialHeaders {
|
||||
|
@ -37,7 +37,7 @@ type jwtSessionLoader struct {
|
||||
// If a session was loaded by a previous handler, it will not be replaced.
|
||||
func (j *jwtSessionLoader) loadSession(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
scope := GetRequestScope(req)
|
||||
scope := middlewareapi.GetRequestScope(req)
|
||||
// If scope is nil, this will panic.
|
||||
// A scope should always be injected before this handler is called.
|
||||
if scope.Session != nil {
|
||||
|
@ -103,8 +103,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
|
||||
// Set up the request with the authorization header and a request scope
|
||||
req := httptest.NewRequest("", "/", nil)
|
||||
req.Header.Set("Authorization", in.authorizationHeader)
|
||||
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
|
||||
req = req.WithContext(contextWithScope)
|
||||
req = middlewareapi.AddRequestScope(req, scope)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
@ -116,7 +115,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
|
||||
// from the scope
|
||||
var gotSession *sessionsapi.SessionState
|
||||
handler := NewJwtSessionLoader(sessionLoaders)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session
|
||||
gotSession = middlewareapi.GetRequestScope(r).Session
|
||||
}))
|
||||
handler.ServeHTTP(rw, req)
|
||||
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
|
||||
requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
|
||||
)
|
||||
|
||||
const httpsScheme = "https"
|
||||
@ -26,10 +26,11 @@ func NewRedirectToHTTPS(httpsPort string) alice.Constructor {
|
||||
// to the port from the httpsAddress given.
|
||||
func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
proto := req.Header.Get("X-Forwarded-Proto")
|
||||
if strings.EqualFold(proto, httpsScheme) || (req.TLS != nil && proto == "") {
|
||||
// Only care about the connection to us being HTTPS if the proto is empty,
|
||||
// otherwise the proto is source of truth
|
||||
proto := requestutil.GetRequestProto(req)
|
||||
if strings.EqualFold(proto, httpsScheme) || (req.TLS != nil && proto == req.URL.Scheme) {
|
||||
// Only care about the connection to us being HTTPS if the proto wasn't
|
||||
// from a trusted `X-Forwarded-Proto` (proto == req.URL.Scheme).
|
||||
// Otherwise the proto is source of truth
|
||||
next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
@ -41,7 +42,7 @@ func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler {
|
||||
|
||||
// Set the Host in case the targetURL still does not have one
|
||||
// or it isn't X-Forwarded-Host aware
|
||||
targetURL.Host = util.GetRequestHost(req)
|
||||
targetURL.Host = requestutil.GetRequestHost(req)
|
||||
|
||||
// Overwrite the port if the original request was to a non-standard port
|
||||
if targetURL.Port() != "" {
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
|
||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/ginkgo/extensions/table"
|
||||
. "github.com/onsi/gomega"
|
||||
@ -21,6 +22,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||
requestString string
|
||||
useTLS bool
|
||||
headers map[string]string
|
||||
reverseProxy bool
|
||||
expectedStatus int
|
||||
expectedBody string
|
||||
expectedLocation string
|
||||
@ -35,6 +37,10 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||
if in.useTLS {
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
}
|
||||
scope := &middlewareapi.RequestScope{
|
||||
ReverseProxy: in.reverseProxy,
|
||||
}
|
||||
req = middlewareapi.AddRequestScope(req, scope)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
@ -52,6 +58,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||
requestString: "http://example.com",
|
||||
useTLS: false,
|
||||
headers: map[string]string{},
|
||||
reverseProxy: false,
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com"),
|
||||
expectedLocation: "https://example.com",
|
||||
@ -60,6 +67,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||
requestString: "https://example.com",
|
||||
useTLS: true,
|
||||
headers: map[string]string{},
|
||||
reverseProxy: false,
|
||||
expectedStatus: 200,
|
||||
expectedBody: "test",
|
||||
}),
|
||||
@ -69,15 +77,28 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "HTTPS",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedStatus: 200,
|
||||
expectedBody: "test",
|
||||
}),
|
||||
Entry("without TLS and X-Forwarded-Proto=HTTPS but ReverseProxy not set", &requestTableInput{
|
||||
requestString: "http://example.com",
|
||||
useTLS: false,
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "HTTPS",
|
||||
},
|
||||
reverseProxy: false,
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com"),
|
||||
expectedLocation: "https://example.com",
|
||||
}),
|
||||
Entry("with TLS and X-Forwarded-Proto=HTTPS", &requestTableInput{
|
||||
requestString: "https://example.com",
|
||||
useTLS: true,
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "HTTPS",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedStatus: 200,
|
||||
expectedBody: "test",
|
||||
}),
|
||||
@ -87,6 +108,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "https",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedStatus: 200,
|
||||
expectedBody: "test",
|
||||
}),
|
||||
@ -96,6 +118,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "https",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedStatus: 200,
|
||||
expectedBody: "test",
|
||||
}),
|
||||
@ -105,6 +128,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "HTTP",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com"),
|
||||
expectedLocation: "https://example.com",
|
||||
@ -115,6 +139,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "HTTP",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com"),
|
||||
expectedLocation: "https://example.com",
|
||||
@ -125,6 +150,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "http",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com"),
|
||||
expectedLocation: "https://example.com",
|
||||
@ -135,6 +161,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "http",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com"),
|
||||
expectedLocation: "https://example.com",
|
||||
@ -143,6 +170,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||
requestString: "http://example.com:8080",
|
||||
useTLS: false,
|
||||
headers: map[string]string{},
|
||||
reverseProxy: false,
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com:8443"),
|
||||
expectedLocation: "https://example.com:8443",
|
||||
@ -151,6 +179,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||
requestString: "https://example.com:8443",
|
||||
useTLS: true,
|
||||
headers: map[string]string{},
|
||||
reverseProxy: false,
|
||||
expectedStatus: 200,
|
||||
expectedBody: "test",
|
||||
}),
|
||||
@ -161,6 +190,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||
requestString: "/",
|
||||
useTLS: false,
|
||||
expectedStatus: 308,
|
||||
reverseProxy: false,
|
||||
expectedBody: permanentRedirectBody("https://example.com/"),
|
||||
expectedLocation: "https://example.com/",
|
||||
}),
|
||||
@ -171,6 +201,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||
"X-Forwarded-Proto": "HTTP",
|
||||
"X-Forwarded-Host": "external.example.com",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://external.example.com"),
|
||||
expectedLocation: "https://external.example.com",
|
||||
|
@ -1,39 +1,20 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
)
|
||||
|
||||
type scopeKey string
|
||||
|
||||
// requestScopeKey uses a typed string to reduce likelihood of clasing
|
||||
// with other context keys
|
||||
const requestScopeKey scopeKey = "request-scope"
|
||||
|
||||
func NewScope() alice.Constructor {
|
||||
return addScope
|
||||
}
|
||||
|
||||
// addScope injects a new request scope into the request context.
|
||||
func addScope(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
scope := &middlewareapi.RequestScope{}
|
||||
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
|
||||
requestWithScope := req.WithContext(contextWithScope)
|
||||
next.ServeHTTP(rw, requestWithScope)
|
||||
})
|
||||
}
|
||||
|
||||
// GetRequestScope returns the current request scope from the given request
|
||||
func GetRequestScope(req *http.Request) *middlewareapi.RequestScope {
|
||||
scope := req.Context().Value(requestScopeKey)
|
||||
if scope == nil {
|
||||
return nil
|
||||
func NewScope(reverseProxy bool) alice.Constructor {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
scope := &middlewareapi.RequestScope{
|
||||
ReverseProxy: reverseProxy,
|
||||
}
|
||||
req = middlewareapi.AddRequestScope(req, scope)
|
||||
next.ServeHTTP(rw, req)
|
||||
})
|
||||
}
|
||||
|
||||
return scope.(*middlewareapi.RequestScope)
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
@ -21,73 +20,49 @@ var _ = Describe("Scope Suite", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
rw = httptest.NewRecorder()
|
||||
|
||||
handler := NewScope()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextRequest = r
|
||||
w.WriteHeader(200)
|
||||
}))
|
||||
handler.ServeHTTP(rw, request)
|
||||
})
|
||||
|
||||
It("does not add a scope to the original request", func() {
|
||||
Expect(request.Context().Value(requestScopeKey)).To(BeNil())
|
||||
})
|
||||
|
||||
It("cannot load a scope from the original request using GetRequestScope", func() {
|
||||
Expect(GetRequestScope(request)).To(BeNil())
|
||||
})
|
||||
|
||||
It("adds a scope to the request for the next handler", func() {
|
||||
Expect(nextRequest.Context().Value(requestScopeKey)).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("can load a scope from the next handler's request using GetRequestScope", func() {
|
||||
Expect(GetRequestScope(nextRequest)).ToNot(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Context("GetRequestScope", func() {
|
||||
var request *http.Request
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
request, err = http.NewRequest("", "http://127.0.0.1/", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
Context("with a scope", func() {
|
||||
var scope *middlewareapi.RequestScope
|
||||
|
||||
Context("ReverseProxy is false", func() {
|
||||
BeforeEach(func() {
|
||||
scope = &middlewareapi.RequestScope{}
|
||||
contextWithScope := context.WithValue(request.Context(), requestScopeKey, scope)
|
||||
request = request.WithContext(contextWithScope)
|
||||
handler := NewScope(false)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextRequest = r
|
||||
w.WriteHeader(200)
|
||||
}))
|
||||
handler.ServeHTTP(rw, request)
|
||||
})
|
||||
|
||||
It("returns the scope", func() {
|
||||
s := GetRequestScope(request)
|
||||
Expect(s).ToNot(BeNil())
|
||||
Expect(s).To(Equal(scope))
|
||||
It("does not add a scope to the original request", func() {
|
||||
Expect(request.Context().Value(middlewareapi.RequestScopeKey)).To(BeNil())
|
||||
})
|
||||
|
||||
Context("if the scope is then modified", func() {
|
||||
BeforeEach(func() {
|
||||
Expect(scope.SaveSession).To(BeFalse())
|
||||
scope.SaveSession = true
|
||||
})
|
||||
It("cannot load a scope from the original request using GetRequestScope", func() {
|
||||
Expect(middlewareapi.GetRequestScope(request)).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns the updated session", func() {
|
||||
s := GetRequestScope(request)
|
||||
Expect(s).ToNot(BeNil())
|
||||
Expect(s).To(Equal(scope))
|
||||
Expect(s.SaveSession).To(BeTrue())
|
||||
})
|
||||
It("adds a scope to the request for the next handler", func() {
|
||||
Expect(nextRequest.Context().Value(middlewareapi.RequestScopeKey)).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("can load a scope from the next handler's request using GetRequestScope", func() {
|
||||
scope := middlewareapi.GetRequestScope(nextRequest)
|
||||
Expect(scope).ToNot(BeNil())
|
||||
Expect(scope.ReverseProxy).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Context("without a scope", func() {
|
||||
It("returns nil", func() {
|
||||
Expect(GetRequestScope(request)).To(BeNil())
|
||||
Context("ReverseProxy is true", func() {
|
||||
BeforeEach(func() {
|
||||
handler := NewScope(true)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextRequest = r
|
||||
w.WriteHeader(200)
|
||||
}))
|
||||
handler.ServeHTTP(rw, request)
|
||||
})
|
||||
|
||||
It("return a scope where the ReverseProxy field is true", func() {
|
||||
scope := middlewareapi.GetRequestScope(nextRequest)
|
||||
Expect(scope).ToNot(BeNil())
|
||||
Expect(scope.ReverseProxy).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
)
|
||||
@ -59,7 +60,7 @@ type storedSessionLoader struct {
|
||||
// If a session was loader by a previous handler, it will not be replaced.
|
||||
func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
scope := GetRequestScope(req)
|
||||
scope := middlewareapi.GetRequestScope(req)
|
||||
// If scope is nil, this will panic.
|
||||
// A scope should always be injected before this handler is called.
|
||||
if scope.Session != nil {
|
||||
|
@ -104,8 +104,7 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
// Set up the request with the request headesr and a request scope
|
||||
req := httptest.NewRequest("", "/", nil)
|
||||
req.Header = in.requestHeaders
|
||||
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
|
||||
req = req.WithContext(contextWithScope)
|
||||
req = middlewareapi.AddRequestScope(req, scope)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
@ -120,7 +119,7 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
// from the scope
|
||||
var gotSession *sessionsapi.SessionState
|
||||
handler := NewStoredSessionLoader(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session
|
||||
gotSession = middlewareapi.GetRequestScope(r).Session
|
||||
}))
|
||||
handler.ServeHTTP(rw, req)
|
||||
|
||||
|
48
pkg/requests/util/util.go
Normal file
48
pkg/requests/util/util.go
Normal file
@ -0,0 +1,48 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
)
|
||||
|
||||
// GetRequestProto returns the request scheme or X-Forwarded-Proto if present
|
||||
// and the request is proxied.
|
||||
func GetRequestProto(req *http.Request) string {
|
||||
proto := req.Header.Get("X-Forwarded-Proto")
|
||||
if !IsProxied(req) || proto == "" {
|
||||
proto = req.URL.Scheme
|
||||
}
|
||||
return proto
|
||||
}
|
||||
|
||||
// GetRequestHost returns the request host header or X-Forwarded-Host if
|
||||
// present and the request is proxied.
|
||||
func GetRequestHost(req *http.Request) string {
|
||||
host := req.Header.Get("X-Forwarded-Host")
|
||||
if !IsProxied(req) || host == "" {
|
||||
host = req.Host
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// GetRequestURI return the request URI or X-Forwarded-Uri if present and the
|
||||
// request is proxied.
|
||||
func GetRequestURI(req *http.Request) string {
|
||||
uri := req.Header.Get("X-Forwarded-Uri")
|
||||
if !IsProxied(req) || uri == "" {
|
||||
// Use RequestURI to preserve ?query
|
||||
uri = req.URL.RequestURI()
|
||||
}
|
||||
return uri
|
||||
}
|
||||
|
||||
// IsProxied determines if a request was from a proxy based on the RequestScope
|
||||
// ReverseProxy tracker.
|
||||
func IsProxied(req *http.Request) bool {
|
||||
scope := middlewareapi.GetRequestScope(req)
|
||||
if scope == nil {
|
||||
return false
|
||||
}
|
||||
return scope.ReverseProxy
|
||||
}
|
19
pkg/requests/util/util_suite_test.go
Normal file
19
pkg/requests/util/util_suite_test.go
Normal file
@ -0,0 +1,19 @@
|
||||
package util_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// TestRequestUtilSuite and related tests are in a *_test package
|
||||
// to prevent circular imports with the `logger` package which uses
|
||||
// this functionality
|
||||
func TestRequestUtilSuite(t *testing.T) {
|
||||
logger.SetOutput(GinkgoWriter)
|
||||
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Request Utils")
|
||||
}
|
131
pkg/requests/util/util_test.go
Normal file
131
pkg/requests/util/util_test.go
Normal file
@ -0,0 +1,131 @@
|
||||
package util_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Util Suite", func() {
|
||||
const (
|
||||
proto = "http"
|
||||
host = "www.oauth2proxy.test"
|
||||
uri = "/test/endpoint"
|
||||
)
|
||||
var req *http.Request
|
||||
|
||||
BeforeEach(func() {
|
||||
req = httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
fmt.Sprintf("%s://%s%s", proto, host, uri),
|
||||
nil,
|
||||
)
|
||||
})
|
||||
|
||||
Context("GetRequestHost", func() {
|
||||
Context("IsProxied is false", func() {
|
||||
BeforeEach(func() {
|
||||
req = middleware.AddRequestScope(req, &middleware.RequestScope{})
|
||||
})
|
||||
|
||||
It("returns the host", func() {
|
||||
Expect(util.GetRequestHost(req)).To(Equal(host))
|
||||
})
|
||||
|
||||
It("ignores X-Forwarded-Host and returns the host", func() {
|
||||
req.Header.Add("X-Forwarded-Host", "external.oauth2proxy.text")
|
||||
Expect(util.GetRequestHost(req)).To(Equal(host))
|
||||
})
|
||||
})
|
||||
|
||||
Context("IsProxied is true", func() {
|
||||
BeforeEach(func() {
|
||||
req = middleware.AddRequestScope(req, &middleware.RequestScope{
|
||||
ReverseProxy: true,
|
||||
})
|
||||
})
|
||||
|
||||
It("returns the host if X-Forwarded-Host is not present", func() {
|
||||
Expect(util.GetRequestHost(req)).To(Equal(host))
|
||||
})
|
||||
|
||||
It("returns the X-Forwarded-Host when present", func() {
|
||||
req.Header.Add("X-Forwarded-Host", "external.oauth2proxy.text")
|
||||
Expect(util.GetRequestHost(req)).To(Equal("external.oauth2proxy.text"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("GetRequestProto", func() {
|
||||
Context("IsProxied is false", func() {
|
||||
BeforeEach(func() {
|
||||
req = middleware.AddRequestScope(req, &middleware.RequestScope{})
|
||||
})
|
||||
|
||||
It("returns the scheme", func() {
|
||||
Expect(util.GetRequestProto(req)).To(Equal(proto))
|
||||
})
|
||||
|
||||
It("ignores X-Forwarded-Proto and returns the scheme", func() {
|
||||
req.Header.Add("X-Forwarded-Proto", "https")
|
||||
Expect(util.GetRequestProto(req)).To(Equal(proto))
|
||||
})
|
||||
})
|
||||
|
||||
Context("IsProxied is true", func() {
|
||||
BeforeEach(func() {
|
||||
req = middleware.AddRequestScope(req, &middleware.RequestScope{
|
||||
ReverseProxy: true,
|
||||
})
|
||||
})
|
||||
|
||||
It("returns the scheme if X-Forwarded-Proto is not present", func() {
|
||||
Expect(util.GetRequestProto(req)).To(Equal(proto))
|
||||
})
|
||||
|
||||
It("returns the X-Forwarded-Proto when present", func() {
|
||||
req.Header.Add("X-Forwarded-Proto", "https")
|
||||
Expect(util.GetRequestProto(req)).To(Equal("https"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("GetRequestURI", func() {
|
||||
Context("IsProxied is false", func() {
|
||||
BeforeEach(func() {
|
||||
req = middleware.AddRequestScope(req, &middleware.RequestScope{})
|
||||
})
|
||||
|
||||
It("returns the URI", func() {
|
||||
Expect(util.GetRequestURI(req)).To(Equal(uri))
|
||||
})
|
||||
|
||||
It("ignores X-Forwarded-Uri and returns the URI", func() {
|
||||
req.Header.Add("X-Forwarded-Uri", "/some/other/path")
|
||||
Expect(util.GetRequestURI(req)).To(Equal(uri))
|
||||
})
|
||||
})
|
||||
|
||||
Context("IsProxied is true", func() {
|
||||
BeforeEach(func() {
|
||||
req = middleware.AddRequestScope(req, &middleware.RequestScope{
|
||||
ReverseProxy: true,
|
||||
})
|
||||
})
|
||||
|
||||
It("returns the URI if X-Forwarded-Uri is not present", func() {
|
||||
Expect(util.GetRequestURI(req)).To(Equal(uri))
|
||||
})
|
||||
|
||||
It("returns the X-Forwarded-Uri when present", func() {
|
||||
req.Header.Add("X-Forwarded-Uri", "/some/other/path")
|
||||
Expect(util.GetRequestURI(req)).To(Equal("/some/other/path"))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
@ -4,7 +4,6 @@ import (
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func GetCertPool(paths []string) (*x509.CertPool, error) {
|
||||
@ -24,31 +23,3 @@ func GetCertPool(paths []string) (*x509.CertPool, error) {
|
||||
}
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// GetRequestProto return the request host header or X-Forwarded-Proto if present
|
||||
func GetRequestProto(req *http.Request) string {
|
||||
proto := req.Header.Get("X-Forwarded-Proto")
|
||||
if proto == "" {
|
||||
proto = req.URL.Scheme
|
||||
}
|
||||
return proto
|
||||
}
|
||||
|
||||
// GetRequestHost return the request host header or X-Forwarded-Host if present
|
||||
func GetRequestHost(req *http.Request) string {
|
||||
host := req.Header.Get("X-Forwarded-Host")
|
||||
if host == "" {
|
||||
host = req.Host
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// GetRequestURI return the request host header or X-Forwarded-Uri if present
|
||||
func GetRequestURI(req *http.Request) string {
|
||||
uri := req.Header.Get("X-Forwarded-Uri")
|
||||
if uri == "" {
|
||||
// Use RequestURI to preserve ?query
|
||||
uri = req.URL.RequestURI()
|
||||
}
|
||||
return uri
|
||||
}
|
||||
|
@ -4,11 +4,9 @@ import (
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"io/ioutil"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@ -97,42 +95,3 @@ func TestGetCertPool(t *testing.T) {
|
||||
expectedSubjects := []string{testCA1Subj, testCA2Subj}
|
||||
assert.Equal(t, expectedSubjects, got)
|
||||
}
|
||||
|
||||
func TestGetRequestHost(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
|
||||
req := httptest.NewRequest("GET", "https://example.com", nil)
|
||||
host := GetRequestHost(req)
|
||||
g.Expect(host).To(Equal("example.com"))
|
||||
|
||||
proxyReq := httptest.NewRequest("GET", "http://internal.example.com", nil)
|
||||
proxyReq.Header.Add("X-Forwarded-Host", "external.example.com")
|
||||
extHost := GetRequestHost(proxyReq)
|
||||
g.Expect(extHost).To(Equal("external.example.com"))
|
||||
}
|
||||
|
||||
func TestGetRequestProto(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
|
||||
req := httptest.NewRequest("GET", "https://example.com", nil)
|
||||
proto := GetRequestProto(req)
|
||||
g.Expect(proto).To(Equal("https"))
|
||||
|
||||
proxyReq := httptest.NewRequest("GET", "https://internal.example.com", nil)
|
||||
proxyReq.Header.Add("X-Forwarded-Proto", "http")
|
||||
extProto := GetRequestProto(proxyReq)
|
||||
g.Expect(extProto).To(Equal("http"))
|
||||
}
|
||||
|
||||
func TestGetRequestURI(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
|
||||
req := httptest.NewRequest("GET", "https://example.com/ping", nil)
|
||||
uri := GetRequestURI(req)
|
||||
g.Expect(uri).To(Equal("/ping"))
|
||||
|
||||
proxyReq := httptest.NewRequest("GET", "http://internal.example.com/bong", nil)
|
||||
proxyReq.Header.Add("X-Forwarded-Uri", "/ping")
|
||||
extURI := GetRequestURI(proxyReq)
|
||||
g.Expect(extURI).To(Equal("/ping"))
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user