mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-01-24 05:26:55 +02:00
Refactor organization of scope aware request utils
Reorganized the structure of the Request Utils due to their widespread use resulting in circular imports issues (mostly because of middleware & logger).
This commit is contained in:
parent
b625de9490
commit
6fb3274ca3
19
pkg/apis/middleware/middleware_suite_test.go
Normal file
19
pkg/apis/middleware/middleware_suite_test.go
Normal file
@ -0,0 +1,19 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// TestMiddlewareSuite and related tests are in a *_test package
|
||||
// to prevent circular imports with the `logger` package which uses
|
||||
// this functionality
|
||||
func TestMiddlewareSuite(t *testing.T) {
|
||||
logger.SetOutput(GinkgoWriter)
|
||||
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Middleware API")
|
||||
}
|
@ -1,9 +1,18 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
)
|
||||
|
||||
type scopeKey string
|
||||
|
||||
// RequestScopeKey uses a typed string to reduce likelihood of clashing
|
||||
// with other context keys
|
||||
const RequestScopeKey scopeKey = "request-scope"
|
||||
|
||||
// RequestScope contains information regarding the request that is being made.
|
||||
// The RequestScope is used to pass information between different middlewares
|
||||
// within the chain.
|
||||
@ -26,3 +35,19 @@ type RequestScope struct {
|
||||
// it was loaded or not.
|
||||
SessionRevalidated bool
|
||||
}
|
||||
|
||||
// GetRequestScope returns the current request scope from the given request
|
||||
func GetRequestScope(req *http.Request) *RequestScope {
|
||||
scope := req.Context().Value(RequestScopeKey)
|
||||
if scope == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return scope.(*RequestScope)
|
||||
}
|
||||
|
||||
// AddRequestScope adds a RequestScope to a request
|
||||
func AddRequestScope(req *http.Request, scope *RequestScope) *http.Request {
|
||||
ctx := context.WithValue(req.Context(), RequestScopeKey, scope)
|
||||
return req.WithContext(ctx)
|
||||
}
|
||||
|
56
pkg/apis/middleware/scope_test.go
Normal file
56
pkg/apis/middleware/scope_test.go
Normal file
@ -0,0 +1,56 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Scope Suite", func() {
|
||||
Context("GetRequestScope", func() {
|
||||
var request *http.Request
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
request, err = http.NewRequest("", "http://127.0.0.1/", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
Context("with a scope", func() {
|
||||
var scope *middleware.RequestScope
|
||||
|
||||
BeforeEach(func() {
|
||||
scope = &middleware.RequestScope{}
|
||||
request = middleware.AddRequestScope(request, scope)
|
||||
})
|
||||
|
||||
It("returns the scope", func() {
|
||||
s := middleware.GetRequestScope(request)
|
||||
Expect(s).ToNot(BeNil())
|
||||
Expect(s).To(Equal(scope))
|
||||
})
|
||||
|
||||
Context("if the scope is then modified", func() {
|
||||
BeforeEach(func() {
|
||||
Expect(scope.SaveSession).To(BeFalse())
|
||||
scope.SaveSession = true
|
||||
})
|
||||
|
||||
It("returns the updated session", func() {
|
||||
s := middleware.GetRequestScope(request)
|
||||
Expect(s).ToNot(BeNil())
|
||||
Expect(s).To(Equal(scope))
|
||||
Expect(s.SaveSession).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("without a scope", func() {
|
||||
It("returns nil", func() {
|
||||
Expect(middleware.GetRequestScope(request)).To(BeNil())
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
@ -9,14 +9,14 @@ import (
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
|
||||
requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
|
||||
)
|
||||
|
||||
// MakeCookie constructs a cookie from the given parameters,
|
||||
// discovering the domain from the request if not specified.
|
||||
func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time, sameSite http.SameSite) *http.Cookie {
|
||||
if domain != "" {
|
||||
host := util.GetRequestHost(req)
|
||||
host := requestutil.GetRequestHost(req)
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = h
|
||||
}
|
||||
@ -48,7 +48,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO
|
||||
// If nothing matches, create the cookie with the shortest domain
|
||||
defaultDomain := ""
|
||||
if len(cookieOpts.Domains) > 0 {
|
||||
logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", util.GetRequestHost(req), strings.Join(cookieOpts.Domains, ","))
|
||||
logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", requestutil.GetRequestHost(req), strings.Join(cookieOpts.Domains, ","))
|
||||
defaultDomain = cookieOpts.Domains[len(cookieOpts.Domains)-1]
|
||||
}
|
||||
return MakeCookie(req, name, value, cookieOpts.Path, defaultDomain, cookieOpts.HTTPOnly, cookieOpts.Secure, expiration, now, ParseSameSite(cookieOpts.SameSite))
|
||||
@ -57,7 +57,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO
|
||||
// GetCookieDomain returns the correct cookie domain given a list of domains
|
||||
// by checking the X-Fowarded-Host and host header of an an http request
|
||||
func GetCookieDomain(req *http.Request, cookieDomains []string) string {
|
||||
host := util.GetRequestHost(req)
|
||||
host := requestutil.GetRequestHost(req)
|
||||
for _, domain := range cookieDomains {
|
||||
if strings.HasSuffix(host, domain) {
|
||||
return domain
|
||||
|
@ -12,7 +12,7 @@ import (
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
|
||||
requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
|
||||
)
|
||||
|
||||
// AuthStatus defines the different types of auth logging that occur
|
||||
@ -197,7 +197,7 @@ func (l *Logger) PrintAuthf(username string, req *http.Request, status AuthStatu
|
||||
|
||||
err := l.authTemplate.Execute(l.writer, authLogMessageData{
|
||||
Client: client,
|
||||
Host: util.GetRequestHost(req),
|
||||
Host: requestutil.GetRequestHost(req),
|
||||
Protocol: req.Proto,
|
||||
RequestMethod: req.Method,
|
||||
Timestamp: FormatTimestamp(now),
|
||||
@ -251,7 +251,7 @@ func (l *Logger) PrintReq(username, upstream string, req *http.Request, url url.
|
||||
|
||||
err := l.reqTemplate.Execute(l.writer, reqLogMessageData{
|
||||
Client: client,
|
||||
Host: util.GetRequestHost(req),
|
||||
Host: requestutil.GetRequestHost(req),
|
||||
Protocol: req.Proto,
|
||||
RequestDuration: fmt.Sprintf("%0.3f", duration),
|
||||
RequestMethod: req.Method,
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
@ -23,7 +24,7 @@ func NewBasicAuthSessionLoader(validator basic.Validator) alice.Constructor {
|
||||
// If a session was loaded by a previous handler, it will not be replaced.
|
||||
func loadBasicAuthSession(validator basic.Validator, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
scope := GetRequestScope(req)
|
||||
scope := middlewareapi.GetRequestScope(req)
|
||||
// If scope is nil, this will panic.
|
||||
// A scope should always be injected before this handler is called.
|
||||
if scope.Session != nil {
|
||||
|
@ -1,7 +1,6 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -40,8 +39,7 @@ var _ = Describe("Basic Auth Session Suite", func() {
|
||||
// Set up the request with the authorization header and a request scope
|
||||
req := httptest.NewRequest("", "/", nil)
|
||||
req.Header.Set("Authorization", in.authorizationHeader)
|
||||
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
|
||||
req = req.WithContext(contextWithScope)
|
||||
req = middlewareapi.AddRequestScope(req, scope)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
@ -57,7 +55,7 @@ var _ = Describe("Basic Auth Session Suite", func() {
|
||||
// from the scope
|
||||
var gotSession *sessionsapi.SessionState
|
||||
handler := NewBasicAuthSessionLoader(validator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session
|
||||
gotSession = middlewareapi.GetRequestScope(r).Session
|
||||
}))
|
||||
handler.ServeHTTP(rw, req)
|
||||
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/header"
|
||||
)
|
||||
@ -61,7 +62,7 @@ func newRequestHeaderInjector(headers []options.Header) (alice.Constructor, erro
|
||||
|
||||
func injectRequestHeaders(injector header.Injector, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
scope := GetRequestScope(req)
|
||||
scope := middlewareapi.GetRequestScope(req)
|
||||
|
||||
// If scope is nil, this will panic.
|
||||
// A scope should always be injected before this handler is called.
|
||||
@ -92,7 +93,7 @@ func newResponseHeaderInjector(headers []options.Header) (alice.Constructor, err
|
||||
|
||||
func injectResponseHeaders(injector header.Injector, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
scope := GetRequestScope(req)
|
||||
scope := middlewareapi.GetRequestScope(req)
|
||||
|
||||
// If scope is nil, this will panic.
|
||||
// A scope should always be injected before this handler is called.
|
||||
|
@ -1,7 +1,6 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -31,8 +30,7 @@ var _ = Describe("Headers Suite", func() {
|
||||
|
||||
// Set up the request with a request scope
|
||||
req := httptest.NewRequest("", "/", nil)
|
||||
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
|
||||
req = req.WithContext(contextWithScope)
|
||||
req = middlewareapi.AddRequestScope(req, scope)
|
||||
req.Header = in.initialHeaders.Clone()
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
@ -218,8 +216,7 @@ var _ = Describe("Headers Suite", func() {
|
||||
|
||||
// Set up the request with a request scope
|
||||
req := httptest.NewRequest("", "/", nil)
|
||||
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
|
||||
req = req.WithContext(contextWithScope)
|
||||
req = middlewareapi.AddRequestScope(req, scope)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
for key, values := range in.initialHeaders {
|
||||
|
@ -37,7 +37,7 @@ type jwtSessionLoader struct {
|
||||
// If a session was loaded by a previous handler, it will not be replaced.
|
||||
func (j *jwtSessionLoader) loadSession(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
scope := GetRequestScope(req)
|
||||
scope := middlewareapi.GetRequestScope(req)
|
||||
// If scope is nil, this will panic.
|
||||
// A scope should always be injected before this handler is called.
|
||||
if scope.Session != nil {
|
||||
|
@ -103,8 +103,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
|
||||
// Set up the request with the authorization header and a request scope
|
||||
req := httptest.NewRequest("", "/", nil)
|
||||
req.Header.Set("Authorization", in.authorizationHeader)
|
||||
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
|
||||
req = req.WithContext(contextWithScope)
|
||||
req = middlewareapi.AddRequestScope(req, scope)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
@ -116,7 +115,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
|
||||
// from the scope
|
||||
var gotSession *sessionsapi.SessionState
|
||||
handler := NewJwtSessionLoader(sessionLoaders)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session
|
||||
gotSession = middlewareapi.GetRequestScope(r).Session
|
||||
}))
|
||||
handler.ServeHTTP(rw, req)
|
||||
|
||||
|
@ -1,39 +1,20 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||
)
|
||||
|
||||
type scopeKey string
|
||||
|
||||
// requestScopeKey uses a typed string to reduce likelihood of clashing
|
||||
// with other context keys
|
||||
const requestScopeKey scopeKey = "request-scope"
|
||||
|
||||
func NewScope(opts *options.Options) alice.Constructor {
|
||||
func NewScope(reverseProxy bool) alice.Constructor {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
scope := &middlewareapi.RequestScope{
|
||||
ReverseProxy: opts.ReverseProxy,
|
||||
ReverseProxy: reverseProxy,
|
||||
}
|
||||
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
|
||||
requestWithScope := req.WithContext(contextWithScope)
|
||||
next.ServeHTTP(rw, requestWithScope)
|
||||
req = middlewareapi.AddRequestScope(req, scope)
|
||||
next.ServeHTTP(rw, req)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// GetRequestScope returns the current request scope from the given request
|
||||
func GetRequestScope(req *http.Request) *middlewareapi.RequestScope {
|
||||
scope := req.Context().Value(requestScopeKey)
|
||||
if scope == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return scope.(*middlewareapi.RequestScope)
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
@ -21,73 +20,49 @@ var _ = Describe("Scope Suite", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
rw = httptest.NewRecorder()
|
||||
|
||||
handler := NewScope()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextRequest = r
|
||||
w.WriteHeader(200)
|
||||
}))
|
||||
handler.ServeHTTP(rw, request)
|
||||
})
|
||||
|
||||
It("does not add a scope to the original request", func() {
|
||||
Expect(request.Context().Value(requestScopeKey)).To(BeNil())
|
||||
})
|
||||
|
||||
It("cannot load a scope from the original request using GetRequestScope", func() {
|
||||
Expect(GetRequestScope(request)).To(BeNil())
|
||||
})
|
||||
|
||||
It("adds a scope to the request for the next handler", func() {
|
||||
Expect(nextRequest.Context().Value(requestScopeKey)).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("can load a scope from the next handler's request using GetRequestScope", func() {
|
||||
Expect(GetRequestScope(nextRequest)).ToNot(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Context("GetRequestScope", func() {
|
||||
var request *http.Request
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
request, err = http.NewRequest("", "http://127.0.0.1/", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
Context("with a scope", func() {
|
||||
var scope *middlewareapi.RequestScope
|
||||
|
||||
Context("ReverseProxy is false", func() {
|
||||
BeforeEach(func() {
|
||||
scope = &middlewareapi.RequestScope{}
|
||||
contextWithScope := context.WithValue(request.Context(), requestScopeKey, scope)
|
||||
request = request.WithContext(contextWithScope)
|
||||
handler := NewScope(false)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextRequest = r
|
||||
w.WriteHeader(200)
|
||||
}))
|
||||
handler.ServeHTTP(rw, request)
|
||||
})
|
||||
|
||||
It("returns the scope", func() {
|
||||
s := GetRequestScope(request)
|
||||
Expect(s).ToNot(BeNil())
|
||||
Expect(s).To(Equal(scope))
|
||||
It("does not add a scope to the original request", func() {
|
||||
Expect(request.Context().Value(middlewareapi.RequestScopeKey)).To(BeNil())
|
||||
})
|
||||
|
||||
Context("if the scope is then modified", func() {
|
||||
BeforeEach(func() {
|
||||
Expect(scope.SaveSession).To(BeFalse())
|
||||
scope.SaveSession = true
|
||||
})
|
||||
It("cannot load a scope from the original request using GetRequestScope", func() {
|
||||
Expect(middlewareapi.GetRequestScope(request)).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns the updated session", func() {
|
||||
s := GetRequestScope(request)
|
||||
Expect(s).ToNot(BeNil())
|
||||
Expect(s).To(Equal(scope))
|
||||
Expect(s.SaveSession).To(BeTrue())
|
||||
})
|
||||
It("adds a scope to the request for the next handler", func() {
|
||||
Expect(nextRequest.Context().Value(middlewareapi.RequestScopeKey)).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("can load a scope from the next handler's request using GetRequestScope", func() {
|
||||
scope := middlewareapi.GetRequestScope(nextRequest)
|
||||
Expect(scope).ToNot(BeNil())
|
||||
Expect(scope.ReverseProxy).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Context("without a scope", func() {
|
||||
It("returns nil", func() {
|
||||
Expect(GetRequestScope(request)).To(BeNil())
|
||||
Context("ReverseProxy is true", func() {
|
||||
BeforeEach(func() {
|
||||
handler := NewScope(true)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextRequest = r
|
||||
w.WriteHeader(200)
|
||||
}))
|
||||
handler.ServeHTTP(rw, request)
|
||||
})
|
||||
|
||||
It("return a scope where the ReverseProxy field is true", func() {
|
||||
scope := middlewareapi.GetRequestScope(nextRequest)
|
||||
Expect(scope).ToNot(BeNil())
|
||||
Expect(scope.ReverseProxy).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
)
|
||||
@ -59,7 +60,7 @@ type storedSessionLoader struct {
|
||||
// If a session was loader by a previous handler, it will not be replaced.
|
||||
func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
scope := GetRequestScope(req)
|
||||
scope := middlewareapi.GetRequestScope(req)
|
||||
// If scope is nil, this will panic.
|
||||
// A scope should always be injected before this handler is called.
|
||||
if scope.Session != nil {
|
||||
|
@ -104,8 +104,7 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
// Set up the request with the request headesr and a request scope
|
||||
req := httptest.NewRequest("", "/", nil)
|
||||
req.Header = in.requestHeaders
|
||||
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
|
||||
req = req.WithContext(contextWithScope)
|
||||
req = middlewareapi.AddRequestScope(req, scope)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
@ -120,7 +119,7 @@ var _ = Describe("Stored Session Suite", func() {
|
||||
// from the scope
|
||||
var gotSession *sessionsapi.SessionState
|
||||
handler := NewStoredSessionLoader(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session
|
||||
gotSession = middlewareapi.GetRequestScope(r).Session
|
||||
}))
|
||||
handler.ServeHTTP(rw, req)
|
||||
|
||||
|
48
pkg/requests/util/util.go
Normal file
48
pkg/requests/util/util.go
Normal file
@ -0,0 +1,48 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
)
|
||||
|
||||
// GetRequestProto returns the request scheme or X-Forwarded-Proto if present
|
||||
// and the request is proxied.
|
||||
func GetRequestProto(req *http.Request) string {
|
||||
proto := req.Header.Get("X-Forwarded-Proto")
|
||||
if !IsProxied(req) || proto == "" {
|
||||
proto = req.URL.Scheme
|
||||
}
|
||||
return proto
|
||||
}
|
||||
|
||||
// GetRequestHost returns the request host header or X-Forwarded-Host if
|
||||
// present and the request is proxied.
|
||||
func GetRequestHost(req *http.Request) string {
|
||||
host := req.Header.Get("X-Forwarded-Host")
|
||||
if !IsProxied(req) || host == "" {
|
||||
host = req.Host
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// GetRequestURI return the request URI or X-Forwarded-Uri if present and the
|
||||
// request is proxied.
|
||||
func GetRequestURI(req *http.Request) string {
|
||||
uri := req.Header.Get("X-Forwarded-Uri")
|
||||
if !IsProxied(req) || uri == "" {
|
||||
// Use RequestURI to preserve ?query
|
||||
uri = req.URL.RequestURI()
|
||||
}
|
||||
return uri
|
||||
}
|
||||
|
||||
// IsProxied determines if a request was from a proxy based on the RequestScope
|
||||
// ReverseProxy tracker.
|
||||
func IsProxied(req *http.Request) bool {
|
||||
scope := middlewareapi.GetRequestScope(req)
|
||||
if scope == nil {
|
||||
return false
|
||||
}
|
||||
return scope.ReverseProxy
|
||||
}
|
19
pkg/requests/util/util_suite_test.go
Normal file
19
pkg/requests/util/util_suite_test.go
Normal file
@ -0,0 +1,19 @@
|
||||
package util_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// TestRequestUtilSuite and related tests are in a *_test package
|
||||
// to prevent circular imports with the `logger` package which uses
|
||||
// this functionality
|
||||
func TestRequestUtilSuite(t *testing.T) {
|
||||
logger.SetOutput(GinkgoWriter)
|
||||
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Request Utils")
|
||||
}
|
131
pkg/requests/util/util_test.go
Normal file
131
pkg/requests/util/util_test.go
Normal file
@ -0,0 +1,131 @@
|
||||
package util_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Util Suite", func() {
|
||||
const (
|
||||
proto = "http"
|
||||
host = "www.oauth2proxy.test"
|
||||
uri = "/test/endpoint"
|
||||
)
|
||||
var req *http.Request
|
||||
|
||||
BeforeEach(func() {
|
||||
req = httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
fmt.Sprintf("%s://%s%s", proto, host, uri),
|
||||
nil,
|
||||
)
|
||||
})
|
||||
|
||||
Context("GetRequestHost", func() {
|
||||
Context("IsProxied is false", func() {
|
||||
BeforeEach(func() {
|
||||
req = middleware.AddRequestScope(req, &middleware.RequestScope{})
|
||||
})
|
||||
|
||||
It("returns the host", func() {
|
||||
Expect(util.GetRequestHost(req)).To(Equal(host))
|
||||
})
|
||||
|
||||
It("ignores X-Forwarded-Host and returns the host", func() {
|
||||
req.Header.Add("X-Forwarded-Host", "external.oauth2proxy.text")
|
||||
Expect(util.GetRequestHost(req)).To(Equal(host))
|
||||
})
|
||||
})
|
||||
|
||||
Context("IsProxied is true", func() {
|
||||
BeforeEach(func() {
|
||||
req = middleware.AddRequestScope(req, &middleware.RequestScope{
|
||||
ReverseProxy: true,
|
||||
})
|
||||
})
|
||||
|
||||
It("returns the host if X-Forwarded-Host is not present", func() {
|
||||
Expect(util.GetRequestHost(req)).To(Equal(host))
|
||||
})
|
||||
|
||||
It("returns the X-Forwarded-Host when present", func() {
|
||||
req.Header.Add("X-Forwarded-Host", "external.oauth2proxy.text")
|
||||
Expect(util.GetRequestHost(req)).To(Equal("external.oauth2proxy.text"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("GetRequestProto", func() {
|
||||
Context("IsProxied is false", func() {
|
||||
BeforeEach(func() {
|
||||
req = middleware.AddRequestScope(req, &middleware.RequestScope{})
|
||||
})
|
||||
|
||||
It("returns the scheme", func() {
|
||||
Expect(util.GetRequestProto(req)).To(Equal(proto))
|
||||
})
|
||||
|
||||
It("ignores X-Forwarded-Proto and returns the scheme", func() {
|
||||
req.Header.Add("X-Forwarded-Proto", "https")
|
||||
Expect(util.GetRequestProto(req)).To(Equal(proto))
|
||||
})
|
||||
})
|
||||
|
||||
Context("IsProxied is true", func() {
|
||||
BeforeEach(func() {
|
||||
req = middleware.AddRequestScope(req, &middleware.RequestScope{
|
||||
ReverseProxy: true,
|
||||
})
|
||||
})
|
||||
|
||||
It("returns the scheme if X-Forwarded-Proto is not present", func() {
|
||||
Expect(util.GetRequestProto(req)).To(Equal(proto))
|
||||
})
|
||||
|
||||
It("returns the X-Forwarded-Proto when present", func() {
|
||||
req.Header.Add("X-Forwarded-Proto", "https")
|
||||
Expect(util.GetRequestProto(req)).To(Equal("https"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("GetRequestURI", func() {
|
||||
Context("IsProxied is false", func() {
|
||||
BeforeEach(func() {
|
||||
req = middleware.AddRequestScope(req, &middleware.RequestScope{})
|
||||
})
|
||||
|
||||
It("returns the URI", func() {
|
||||
Expect(util.GetRequestURI(req)).To(Equal(uri))
|
||||
})
|
||||
|
||||
It("ignores X-Forwarded-Uri and returns the URI", func() {
|
||||
req.Header.Add("X-Forwarded-Uri", "/some/other/path")
|
||||
Expect(util.GetRequestURI(req)).To(Equal(uri))
|
||||
})
|
||||
})
|
||||
|
||||
Context("IsProxied is true", func() {
|
||||
BeforeEach(func() {
|
||||
req = middleware.AddRequestScope(req, &middleware.RequestScope{
|
||||
ReverseProxy: true,
|
||||
})
|
||||
})
|
||||
|
||||
It("returns the URI if X-Forwarded-Uri is not present", func() {
|
||||
Expect(util.GetRequestURI(req)).To(Equal(uri))
|
||||
})
|
||||
|
||||
It("returns the X-Forwarded-Uri when present", func() {
|
||||
req.Header.Add("X-Forwarded-Uri", "/some/other/path")
|
||||
Expect(util.GetRequestURI(req)).To(Equal("/some/other/path"))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
@ -4,9 +4,6 @@ import (
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
|
||||
)
|
||||
|
||||
func GetCertPool(paths []string) (*x509.CertPool, error) {
|
||||
@ -26,37 +23,3 @@ func GetCertPool(paths []string) (*x509.CertPool, error) {
|
||||
}
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// GetRequestProto return the request host header or X-Forwarded-Proto if present
|
||||
func GetRequestProto(req *http.Request) string {
|
||||
proto := req.Header.Get("X-Forwarded-Proto")
|
||||
if !isProxied(req) || proto == "" {
|
||||
proto = req.URL.Scheme
|
||||
}
|
||||
return proto
|
||||
}
|
||||
|
||||
// GetRequestHost return the request host header or X-Forwarded-Host if present
|
||||
// and reverse proxy mode is enabled.
|
||||
func GetRequestHost(req *http.Request) string {
|
||||
host := req.Header.Get("X-Forwarded-Host")
|
||||
if !isProxied(req) || host == "" {
|
||||
host = req.Host
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// GetRequestURI return the request host header or X-Forwarded-Uri if present
|
||||
func GetRequestURI(req *http.Request) string {
|
||||
uri := req.Header.Get("X-Forwarded-Uri")
|
||||
if !isProxied(req) || uri == "" {
|
||||
// Use RequestURI to preserve ?query
|
||||
uri = req.URL.RequestURI()
|
||||
}
|
||||
return uri
|
||||
}
|
||||
|
||||
func isProxied(req *http.Request) bool {
|
||||
scope := middleware.GetRequestScope(req)
|
||||
return scope.ReverseProxy
|
||||
}
|
||||
|
@ -4,11 +4,9 @@ import (
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"io/ioutil"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@ -97,42 +95,3 @@ func TestGetCertPool(t *testing.T) {
|
||||
expectedSubjects := []string{testCA1Subj, testCA2Subj}
|
||||
assert.Equal(t, expectedSubjects, got)
|
||||
}
|
||||
|
||||
func TestGetRequestHost(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
|
||||
req := httptest.NewRequest("GET", "https://example.com", nil)
|
||||
host := GetRequestHost(req)
|
||||
g.Expect(host).To(Equal("example.com"))
|
||||
|
||||
proxyReq := httptest.NewRequest("GET", "http://internal.example.com", nil)
|
||||
proxyReq.Header.Add("X-Forwarded-Host", "external.example.com")
|
||||
extHost := GetRequestHost(proxyReq)
|
||||
g.Expect(extHost).To(Equal("external.example.com"))
|
||||
}
|
||||
|
||||
func TestGetRequestProto(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
|
||||
req := httptest.NewRequest("GET", "https://example.com", nil)
|
||||
proto := GetRequestProto(req)
|
||||
g.Expect(proto).To(Equal("https"))
|
||||
|
||||
proxyReq := httptest.NewRequest("GET", "https://internal.example.com", nil)
|
||||
proxyReq.Header.Add("X-Forwarded-Proto", "http")
|
||||
extProto := GetRequestProto(proxyReq)
|
||||
g.Expect(extProto).To(Equal("http"))
|
||||
}
|
||||
|
||||
func TestGetRequestURI(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
|
||||
req := httptest.NewRequest("GET", "https://example.com/ping", nil)
|
||||
uri := GetRequestURI(req)
|
||||
g.Expect(uri).To(Equal("/ping"))
|
||||
|
||||
proxyReq := httptest.NewRequest("GET", "http://internal.example.com/bong", nil)
|
||||
proxyReq.Header.Add("X-Forwarded-Uri", "/ping")
|
||||
extURI := GetRequestURI(proxyReq)
|
||||
g.Expect(extURI).To(Equal("/ping"))
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user