1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-01-26 05:27:28 +02:00

Refactor organization of scope aware request utils

Reorganized the structure of the Request Utils due to their widespread use
resulting in circular imports issues (mostly because of middleware & logger).
This commit is contained in:
Nick Meves 2021-01-02 13:16:01 -08:00
parent b625de9490
commit 6fb3274ca3
No known key found for this signature in database
GPG Key ID: 93BA8A3CEDCDD1CF
20 changed files with 357 additions and 185 deletions

View File

@ -0,0 +1,19 @@
package middleware_test
import (
"testing"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
// TestMiddlewareSuite and related tests are in a *_test package
// to prevent circular imports with the `logger` package which uses
// this functionality
func TestMiddlewareSuite(t *testing.T) {
logger.SetOutput(GinkgoWriter)
RegisterFailHandler(Fail)
RunSpecs(t, "Middleware API")
}

View File

@ -1,9 +1,18 @@
package middleware package middleware
import ( import (
"context"
"net/http"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
) )
type scopeKey string
// RequestScopeKey uses a typed string to reduce likelihood of clashing
// with other context keys
const RequestScopeKey scopeKey = "request-scope"
// RequestScope contains information regarding the request that is being made. // RequestScope contains information regarding the request that is being made.
// The RequestScope is used to pass information between different middlewares // The RequestScope is used to pass information between different middlewares
// within the chain. // within the chain.
@ -26,3 +35,19 @@ type RequestScope struct {
// it was loaded or not. // it was loaded or not.
SessionRevalidated bool SessionRevalidated bool
} }
// GetRequestScope returns the current request scope from the given request
func GetRequestScope(req *http.Request) *RequestScope {
scope := req.Context().Value(RequestScopeKey)
if scope == nil {
return nil
}
return scope.(*RequestScope)
}
// AddRequestScope adds a RequestScope to a request
func AddRequestScope(req *http.Request, scope *RequestScope) *http.Request {
ctx := context.WithValue(req.Context(), RequestScopeKey, scope)
return req.WithContext(ctx)
}

View File

@ -0,0 +1,56 @@
package middleware_test
import (
"net/http"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Scope Suite", func() {
Context("GetRequestScope", func() {
var request *http.Request
BeforeEach(func() {
var err error
request, err = http.NewRequest("", "http://127.0.0.1/", nil)
Expect(err).ToNot(HaveOccurred())
})
Context("with a scope", func() {
var scope *middleware.RequestScope
BeforeEach(func() {
scope = &middleware.RequestScope{}
request = middleware.AddRequestScope(request, scope)
})
It("returns the scope", func() {
s := middleware.GetRequestScope(request)
Expect(s).ToNot(BeNil())
Expect(s).To(Equal(scope))
})
Context("if the scope is then modified", func() {
BeforeEach(func() {
Expect(scope.SaveSession).To(BeFalse())
scope.SaveSession = true
})
It("returns the updated session", func() {
s := middleware.GetRequestScope(request)
Expect(s).ToNot(BeNil())
Expect(s).To(Equal(scope))
Expect(s.SaveSession).To(BeTrue())
})
})
})
Context("without a scope", func() {
It("returns nil", func() {
Expect(middleware.GetRequestScope(request)).To(BeNil())
})
})
})
})

View File

@ -9,14 +9,14 @@ import (
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
) )
// MakeCookie constructs a cookie from the given parameters, // MakeCookie constructs a cookie from the given parameters,
// discovering the domain from the request if not specified. // discovering the domain from the request if not specified.
func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time, sameSite http.SameSite) *http.Cookie { func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time, sameSite http.SameSite) *http.Cookie {
if domain != "" { if domain != "" {
host := util.GetRequestHost(req) host := requestutil.GetRequestHost(req)
if h, _, err := net.SplitHostPort(host); err == nil { if h, _, err := net.SplitHostPort(host); err == nil {
host = h host = h
} }
@ -48,7 +48,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO
// If nothing matches, create the cookie with the shortest domain // If nothing matches, create the cookie with the shortest domain
defaultDomain := "" defaultDomain := ""
if len(cookieOpts.Domains) > 0 { if len(cookieOpts.Domains) > 0 {
logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", util.GetRequestHost(req), strings.Join(cookieOpts.Domains, ",")) logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", requestutil.GetRequestHost(req), strings.Join(cookieOpts.Domains, ","))
defaultDomain = cookieOpts.Domains[len(cookieOpts.Domains)-1] defaultDomain = cookieOpts.Domains[len(cookieOpts.Domains)-1]
} }
return MakeCookie(req, name, value, cookieOpts.Path, defaultDomain, cookieOpts.HTTPOnly, cookieOpts.Secure, expiration, now, ParseSameSite(cookieOpts.SameSite)) return MakeCookie(req, name, value, cookieOpts.Path, defaultDomain, cookieOpts.HTTPOnly, cookieOpts.Secure, expiration, now, ParseSameSite(cookieOpts.SameSite))
@ -57,7 +57,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO
// GetCookieDomain returns the correct cookie domain given a list of domains // GetCookieDomain returns the correct cookie domain given a list of domains
// by checking the X-Fowarded-Host and host header of an an http request // by checking the X-Fowarded-Host and host header of an an http request
func GetCookieDomain(req *http.Request, cookieDomains []string) string { func GetCookieDomain(req *http.Request, cookieDomains []string) string {
host := util.GetRequestHost(req) host := requestutil.GetRequestHost(req)
for _, domain := range cookieDomains { for _, domain := range cookieDomains {
if strings.HasSuffix(host, domain) { if strings.HasSuffix(host, domain) {
return domain return domain

View File

@ -12,7 +12,7 @@ import (
"text/template" "text/template"
"time" "time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
) )
// AuthStatus defines the different types of auth logging that occur // AuthStatus defines the different types of auth logging that occur
@ -197,7 +197,7 @@ func (l *Logger) PrintAuthf(username string, req *http.Request, status AuthStatu
err := l.authTemplate.Execute(l.writer, authLogMessageData{ err := l.authTemplate.Execute(l.writer, authLogMessageData{
Client: client, Client: client,
Host: util.GetRequestHost(req), Host: requestutil.GetRequestHost(req),
Protocol: req.Proto, Protocol: req.Proto,
RequestMethod: req.Method, RequestMethod: req.Method,
Timestamp: FormatTimestamp(now), Timestamp: FormatTimestamp(now),
@ -251,7 +251,7 @@ func (l *Logger) PrintReq(username, upstream string, req *http.Request, url url.
err := l.reqTemplate.Execute(l.writer, reqLogMessageData{ err := l.reqTemplate.Execute(l.writer, reqLogMessageData{
Client: client, Client: client,
Host: util.GetRequestHost(req), Host: requestutil.GetRequestHost(req),
Protocol: req.Proto, Protocol: req.Proto,
RequestDuration: fmt.Sprintf("%0.3f", duration), RequestDuration: fmt.Sprintf("%0.3f", duration),
RequestMethod: req.Method, RequestMethod: req.Method,

View File

@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"github.com/justinas/alice" "github.com/justinas/alice"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
@ -23,7 +24,7 @@ func NewBasicAuthSessionLoader(validator basic.Validator) alice.Constructor {
// If a session was loaded by a previous handler, it will not be replaced. // If a session was loaded by a previous handler, it will not be replaced.
func loadBasicAuthSession(validator basic.Validator, next http.Handler) http.Handler { func loadBasicAuthSession(validator basic.Validator, next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
scope := GetRequestScope(req) scope := middlewareapi.GetRequestScope(req)
// If scope is nil, this will panic. // If scope is nil, this will panic.
// A scope should always be injected before this handler is called. // A scope should always be injected before this handler is called.
if scope.Session != nil { if scope.Session != nil {

View File

@ -1,7 +1,6 @@
package middleware package middleware
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -40,8 +39,7 @@ var _ = Describe("Basic Auth Session Suite", func() {
// Set up the request with the authorization header and a request scope // Set up the request with the authorization header and a request scope
req := httptest.NewRequest("", "/", nil) req := httptest.NewRequest("", "/", nil)
req.Header.Set("Authorization", in.authorizationHeader) req.Header.Set("Authorization", in.authorizationHeader)
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) req = middlewareapi.AddRequestScope(req, scope)
req = req.WithContext(contextWithScope)
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
@ -57,7 +55,7 @@ var _ = Describe("Basic Auth Session Suite", func() {
// from the scope // from the scope
var gotSession *sessionsapi.SessionState var gotSession *sessionsapi.SessionState
handler := NewBasicAuthSessionLoader(validator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := NewBasicAuthSessionLoader(validator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session gotSession = middlewareapi.GetRequestScope(r).Session
})) }))
handler.ServeHTTP(rw, req) handler.ServeHTTP(rw, req)

View File

@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"github.com/justinas/alice" "github.com/justinas/alice"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/header" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/header"
) )
@ -61,7 +62,7 @@ func newRequestHeaderInjector(headers []options.Header) (alice.Constructor, erro
func injectRequestHeaders(injector header.Injector, next http.Handler) http.Handler { func injectRequestHeaders(injector header.Injector, next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
scope := GetRequestScope(req) scope := middlewareapi.GetRequestScope(req)
// If scope is nil, this will panic. // If scope is nil, this will panic.
// A scope should always be injected before this handler is called. // A scope should always be injected before this handler is called.
@ -92,7 +93,7 @@ func newResponseHeaderInjector(headers []options.Header) (alice.Constructor, err
func injectResponseHeaders(injector header.Injector, next http.Handler) http.Handler { func injectResponseHeaders(injector header.Injector, next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
scope := GetRequestScope(req) scope := middlewareapi.GetRequestScope(req)
// If scope is nil, this will panic. // If scope is nil, this will panic.
// A scope should always be injected before this handler is called. // A scope should always be injected before this handler is called.

View File

@ -1,7 +1,6 @@
package middleware package middleware
import ( import (
"context"
"encoding/base64" "encoding/base64"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -31,8 +30,7 @@ var _ = Describe("Headers Suite", func() {
// Set up the request with a request scope // Set up the request with a request scope
req := httptest.NewRequest("", "/", nil) req := httptest.NewRequest("", "/", nil)
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) req = middlewareapi.AddRequestScope(req, scope)
req = req.WithContext(contextWithScope)
req.Header = in.initialHeaders.Clone() req.Header = in.initialHeaders.Clone()
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
@ -218,8 +216,7 @@ var _ = Describe("Headers Suite", func() {
// Set up the request with a request scope // Set up the request with a request scope
req := httptest.NewRequest("", "/", nil) req := httptest.NewRequest("", "/", nil)
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) req = middlewareapi.AddRequestScope(req, scope)
req = req.WithContext(contextWithScope)
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
for key, values := range in.initialHeaders { for key, values := range in.initialHeaders {

View File

@ -37,7 +37,7 @@ type jwtSessionLoader struct {
// If a session was loaded by a previous handler, it will not be replaced. // If a session was loaded by a previous handler, it will not be replaced.
func (j *jwtSessionLoader) loadSession(next http.Handler) http.Handler { func (j *jwtSessionLoader) loadSession(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
scope := GetRequestScope(req) scope := middlewareapi.GetRequestScope(req)
// If scope is nil, this will panic. // If scope is nil, this will panic.
// A scope should always be injected before this handler is called. // A scope should always be injected before this handler is called.
if scope.Session != nil { if scope.Session != nil {

View File

@ -103,8 +103,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
// Set up the request with the authorization header and a request scope // Set up the request with the authorization header and a request scope
req := httptest.NewRequest("", "/", nil) req := httptest.NewRequest("", "/", nil)
req.Header.Set("Authorization", in.authorizationHeader) req.Header.Set("Authorization", in.authorizationHeader)
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) req = middlewareapi.AddRequestScope(req, scope)
req = req.WithContext(contextWithScope)
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
@ -116,7 +115,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
// from the scope // from the scope
var gotSession *sessionsapi.SessionState var gotSession *sessionsapi.SessionState
handler := NewJwtSessionLoader(sessionLoaders)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := NewJwtSessionLoader(sessionLoaders)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session gotSession = middlewareapi.GetRequestScope(r).Session
})) }))
handler.ServeHTTP(rw, req) handler.ServeHTTP(rw, req)

View File

@ -1,39 +1,20 @@
package middleware package middleware
import ( import (
"context"
"net/http" "net/http"
"github.com/justinas/alice" "github.com/justinas/alice"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
) )
type scopeKey string func NewScope(reverseProxy bool) alice.Constructor {
// requestScopeKey uses a typed string to reduce likelihood of clashing
// with other context keys
const requestScopeKey scopeKey = "request-scope"
func NewScope(opts *options.Options) alice.Constructor {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
scope := &middlewareapi.RequestScope{ scope := &middlewareapi.RequestScope{
ReverseProxy: opts.ReverseProxy, ReverseProxy: reverseProxy,
} }
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) req = middlewareapi.AddRequestScope(req, scope)
requestWithScope := req.WithContext(contextWithScope) next.ServeHTTP(rw, req)
next.ServeHTTP(rw, requestWithScope)
}) })
} }
} }
// GetRequestScope returns the current request scope from the given request
func GetRequestScope(req *http.Request) *middlewareapi.RequestScope {
scope := req.Context().Value(requestScopeKey)
if scope == nil {
return nil
}
return scope.(*middlewareapi.RequestScope)
}

View File

@ -1,7 +1,6 @@
package middleware package middleware
import ( import (
"context"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -21,73 +20,49 @@ var _ = Describe("Scope Suite", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
rw = httptest.NewRecorder() rw = httptest.NewRecorder()
handler := NewScope()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextRequest = r
w.WriteHeader(200)
}))
handler.ServeHTTP(rw, request)
}) })
It("does not add a scope to the original request", func() { Context("ReverseProxy is false", func() {
Expect(request.Context().Value(requestScopeKey)).To(BeNil())
})
It("cannot load a scope from the original request using GetRequestScope", func() {
Expect(GetRequestScope(request)).To(BeNil())
})
It("adds a scope to the request for the next handler", func() {
Expect(nextRequest.Context().Value(requestScopeKey)).ToNot(BeNil())
})
It("can load a scope from the next handler's request using GetRequestScope", func() {
Expect(GetRequestScope(nextRequest)).ToNot(BeNil())
})
})
Context("GetRequestScope", func() {
var request *http.Request
BeforeEach(func() {
var err error
request, err = http.NewRequest("", "http://127.0.0.1/", nil)
Expect(err).ToNot(HaveOccurred())
})
Context("with a scope", func() {
var scope *middlewareapi.RequestScope
BeforeEach(func() { BeforeEach(func() {
scope = &middlewareapi.RequestScope{} handler := NewScope(false)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
contextWithScope := context.WithValue(request.Context(), requestScopeKey, scope) nextRequest = r
request = request.WithContext(contextWithScope) w.WriteHeader(200)
}))
handler.ServeHTTP(rw, request)
}) })
It("returns the scope", func() { It("does not add a scope to the original request", func() {
s := GetRequestScope(request) Expect(request.Context().Value(middlewareapi.RequestScopeKey)).To(BeNil())
Expect(s).ToNot(BeNil())
Expect(s).To(Equal(scope))
}) })
Context("if the scope is then modified", func() { It("cannot load a scope from the original request using GetRequestScope", func() {
BeforeEach(func() { Expect(middlewareapi.GetRequestScope(request)).To(BeNil())
Expect(scope.SaveSession).To(BeFalse()) })
scope.SaveSession = true
})
It("returns the updated session", func() { It("adds a scope to the request for the next handler", func() {
s := GetRequestScope(request) Expect(nextRequest.Context().Value(middlewareapi.RequestScopeKey)).ToNot(BeNil())
Expect(s).ToNot(BeNil()) })
Expect(s).To(Equal(scope))
Expect(s.SaveSession).To(BeTrue()) It("can load a scope from the next handler's request using GetRequestScope", func() {
}) scope := middlewareapi.GetRequestScope(nextRequest)
Expect(scope).ToNot(BeNil())
Expect(scope.ReverseProxy).To(BeFalse())
}) })
}) })
Context("without a scope", func() { Context("ReverseProxy is true", func() {
It("returns nil", func() { BeforeEach(func() {
Expect(GetRequestScope(request)).To(BeNil()) handler := NewScope(true)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextRequest = r
w.WriteHeader(200)
}))
handler.ServeHTTP(rw, request)
})
It("return a scope where the ReverseProxy field is true", func() {
scope := middlewareapi.GetRequestScope(nextRequest)
Expect(scope).ToNot(BeNil())
Expect(scope.ReverseProxy).To(BeTrue())
}) })
}) })
}) })

View File

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/justinas/alice" "github.com/justinas/alice"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
) )
@ -59,7 +60,7 @@ type storedSessionLoader struct {
// If a session was loader by a previous handler, it will not be replaced. // If a session was loader by a previous handler, it will not be replaced.
func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler { func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
scope := GetRequestScope(req) scope := middlewareapi.GetRequestScope(req)
// If scope is nil, this will panic. // If scope is nil, this will panic.
// A scope should always be injected before this handler is called. // A scope should always be injected before this handler is called.
if scope.Session != nil { if scope.Session != nil {

View File

@ -104,8 +104,7 @@ var _ = Describe("Stored Session Suite", func() {
// Set up the request with the request headesr and a request scope // Set up the request with the request headesr and a request scope
req := httptest.NewRequest("", "/", nil) req := httptest.NewRequest("", "/", nil)
req.Header = in.requestHeaders req.Header = in.requestHeaders
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) req = middlewareapi.AddRequestScope(req, scope)
req = req.WithContext(contextWithScope)
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
@ -120,7 +119,7 @@ var _ = Describe("Stored Session Suite", func() {
// from the scope // from the scope
var gotSession *sessionsapi.SessionState var gotSession *sessionsapi.SessionState
handler := NewStoredSessionLoader(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := NewStoredSessionLoader(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session gotSession = middlewareapi.GetRequestScope(r).Session
})) }))
handler.ServeHTTP(rw, req) handler.ServeHTTP(rw, req)

48
pkg/requests/util/util.go Normal file
View File

@ -0,0 +1,48 @@
package util
import (
"net/http"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
)
// GetRequestProto returns the request scheme or X-Forwarded-Proto if present
// and the request is proxied.
func GetRequestProto(req *http.Request) string {
proto := req.Header.Get("X-Forwarded-Proto")
if !IsProxied(req) || proto == "" {
proto = req.URL.Scheme
}
return proto
}
// GetRequestHost returns the request host header or X-Forwarded-Host if
// present and the request is proxied.
func GetRequestHost(req *http.Request) string {
host := req.Header.Get("X-Forwarded-Host")
if !IsProxied(req) || host == "" {
host = req.Host
}
return host
}
// GetRequestURI return the request URI or X-Forwarded-Uri if present and the
// request is proxied.
func GetRequestURI(req *http.Request) string {
uri := req.Header.Get("X-Forwarded-Uri")
if !IsProxied(req) || uri == "" {
// Use RequestURI to preserve ?query
uri = req.URL.RequestURI()
}
return uri
}
// IsProxied determines if a request was from a proxy based on the RequestScope
// ReverseProxy tracker.
func IsProxied(req *http.Request) bool {
scope := middlewareapi.GetRequestScope(req)
if scope == nil {
return false
}
return scope.ReverseProxy
}

View File

@ -0,0 +1,19 @@
package util_test
import (
"testing"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
// TestRequestUtilSuite and related tests are in a *_test package
// to prevent circular imports with the `logger` package which uses
// this functionality
func TestRequestUtilSuite(t *testing.T) {
logger.SetOutput(GinkgoWriter)
RegisterFailHandler(Fail)
RunSpecs(t, "Request Utils")
}

View File

@ -0,0 +1,131 @@
package util_test
import (
"fmt"
"net/http"
"net/http/httptest"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Util Suite", func() {
const (
proto = "http"
host = "www.oauth2proxy.test"
uri = "/test/endpoint"
)
var req *http.Request
BeforeEach(func() {
req = httptest.NewRequest(
http.MethodGet,
fmt.Sprintf("%s://%s%s", proto, host, uri),
nil,
)
})
Context("GetRequestHost", func() {
Context("IsProxied is false", func() {
BeforeEach(func() {
req = middleware.AddRequestScope(req, &middleware.RequestScope{})
})
It("returns the host", func() {
Expect(util.GetRequestHost(req)).To(Equal(host))
})
It("ignores X-Forwarded-Host and returns the host", func() {
req.Header.Add("X-Forwarded-Host", "external.oauth2proxy.text")
Expect(util.GetRequestHost(req)).To(Equal(host))
})
})
Context("IsProxied is true", func() {
BeforeEach(func() {
req = middleware.AddRequestScope(req, &middleware.RequestScope{
ReverseProxy: true,
})
})
It("returns the host if X-Forwarded-Host is not present", func() {
Expect(util.GetRequestHost(req)).To(Equal(host))
})
It("returns the X-Forwarded-Host when present", func() {
req.Header.Add("X-Forwarded-Host", "external.oauth2proxy.text")
Expect(util.GetRequestHost(req)).To(Equal("external.oauth2proxy.text"))
})
})
})
Context("GetRequestProto", func() {
Context("IsProxied is false", func() {
BeforeEach(func() {
req = middleware.AddRequestScope(req, &middleware.RequestScope{})
})
It("returns the scheme", func() {
Expect(util.GetRequestProto(req)).To(Equal(proto))
})
It("ignores X-Forwarded-Proto and returns the scheme", func() {
req.Header.Add("X-Forwarded-Proto", "https")
Expect(util.GetRequestProto(req)).To(Equal(proto))
})
})
Context("IsProxied is true", func() {
BeforeEach(func() {
req = middleware.AddRequestScope(req, &middleware.RequestScope{
ReverseProxy: true,
})
})
It("returns the scheme if X-Forwarded-Proto is not present", func() {
Expect(util.GetRequestProto(req)).To(Equal(proto))
})
It("returns the X-Forwarded-Proto when present", func() {
req.Header.Add("X-Forwarded-Proto", "https")
Expect(util.GetRequestProto(req)).To(Equal("https"))
})
})
})
Context("GetRequestURI", func() {
Context("IsProxied is false", func() {
BeforeEach(func() {
req = middleware.AddRequestScope(req, &middleware.RequestScope{})
})
It("returns the URI", func() {
Expect(util.GetRequestURI(req)).To(Equal(uri))
})
It("ignores X-Forwarded-Uri and returns the URI", func() {
req.Header.Add("X-Forwarded-Uri", "/some/other/path")
Expect(util.GetRequestURI(req)).To(Equal(uri))
})
})
Context("IsProxied is true", func() {
BeforeEach(func() {
req = middleware.AddRequestScope(req, &middleware.RequestScope{
ReverseProxy: true,
})
})
It("returns the URI if X-Forwarded-Uri is not present", func() {
Expect(util.GetRequestURI(req)).To(Equal(uri))
})
It("returns the X-Forwarded-Uri when present", func() {
req.Header.Add("X-Forwarded-Uri", "/some/other/path")
Expect(util.GetRequestURI(req)).To(Equal("/some/other/path"))
})
})
})
})

View File

@ -4,9 +4,6 @@ import (
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
) )
func GetCertPool(paths []string) (*x509.CertPool, error) { func GetCertPool(paths []string) (*x509.CertPool, error) {
@ -26,37 +23,3 @@ func GetCertPool(paths []string) (*x509.CertPool, error) {
} }
return pool, nil return pool, nil
} }
// GetRequestProto return the request host header or X-Forwarded-Proto if present
func GetRequestProto(req *http.Request) string {
proto := req.Header.Get("X-Forwarded-Proto")
if !isProxied(req) || proto == "" {
proto = req.URL.Scheme
}
return proto
}
// GetRequestHost return the request host header or X-Forwarded-Host if present
// and reverse proxy mode is enabled.
func GetRequestHost(req *http.Request) string {
host := req.Header.Get("X-Forwarded-Host")
if !isProxied(req) || host == "" {
host = req.Host
}
return host
}
// GetRequestURI return the request host header or X-Forwarded-Uri if present
func GetRequestURI(req *http.Request) string {
uri := req.Header.Get("X-Forwarded-Uri")
if !isProxied(req) || uri == "" {
// Use RequestURI to preserve ?query
uri = req.URL.RequestURI()
}
return uri
}
func isProxied(req *http.Request) bool {
scope := middleware.GetRequestScope(req)
return scope.ReverseProxy
}

View File

@ -4,11 +4,9 @@ import (
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/asn1" "encoding/asn1"
"io/ioutil" "io/ioutil"
"net/http/httptest"
"os" "os"
"testing" "testing"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -97,42 +95,3 @@ func TestGetCertPool(t *testing.T) {
expectedSubjects := []string{testCA1Subj, testCA2Subj} expectedSubjects := []string{testCA1Subj, testCA2Subj}
assert.Equal(t, expectedSubjects, got) assert.Equal(t, expectedSubjects, got)
} }
func TestGetRequestHost(t *testing.T) {
g := NewWithT(t)
req := httptest.NewRequest("GET", "https://example.com", nil)
host := GetRequestHost(req)
g.Expect(host).To(Equal("example.com"))
proxyReq := httptest.NewRequest("GET", "http://internal.example.com", nil)
proxyReq.Header.Add("X-Forwarded-Host", "external.example.com")
extHost := GetRequestHost(proxyReq)
g.Expect(extHost).To(Equal("external.example.com"))
}
func TestGetRequestProto(t *testing.T) {
g := NewWithT(t)
req := httptest.NewRequest("GET", "https://example.com", nil)
proto := GetRequestProto(req)
g.Expect(proto).To(Equal("https"))
proxyReq := httptest.NewRequest("GET", "https://internal.example.com", nil)
proxyReq.Header.Add("X-Forwarded-Proto", "http")
extProto := GetRequestProto(proxyReq)
g.Expect(extProto).To(Equal("http"))
}
func TestGetRequestURI(t *testing.T) {
g := NewWithT(t)
req := httptest.NewRequest("GET", "https://example.com/ping", nil)
uri := GetRequestURI(req)
g.Expect(uri).To(Equal("/ping"))
proxyReq := httptest.NewRequest("GET", "http://internal.example.com/bong", nil)
proxyReq.Header.Add("X-Forwarded-Uri", "/ping")
extURI := GetRequestURI(proxyReq)
g.Expect(extURI).To(Equal("/ping"))
}