mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-01-10 04:18:14 +02:00
c1267bb92d
* Add RequestID to the RequestScope * Expose RequestID to auth & request loggers * Use the RequestID in templated HTML pages * Allow customizing the RequestID header * Document new Request ID support * Add more cases to scope/requestID tests * Split Get vs Generate RequestID funtionality * Add {{.RequestID}} to the request logger tests * Move RequestID management to RequestScope * Use HTML escape instead of sanitization for Request ID rendering
128 lines
3.4 KiB
Go
128 lines
3.4 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
|
|
"github.com/google/uuid"
|
|
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
|
. "github.com/onsi/ginkgo"
|
|
. "github.com/onsi/gomega"
|
|
)
|
|
|
|
const (
|
|
testRequestHeader = "X-Request-Id"
|
|
testRequestID = "11111111-2222-4333-8444-555555555555"
|
|
// mockRand io.Reader below counts bytes from 0-255 in order
|
|
testRandomUUID = "00010203-0405-4607-8809-0a0b0c0d0e0f"
|
|
)
|
|
|
|
var _ = Describe("Scope Suite", func() {
|
|
Context("NewScope", func() {
|
|
var request, nextRequest *http.Request
|
|
var rw http.ResponseWriter
|
|
|
|
BeforeEach(func() {
|
|
var err error
|
|
request, err = http.NewRequest("", "http://127.0.0.1/", nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
rw = httptest.NewRecorder()
|
|
})
|
|
|
|
Context("ReverseProxy is false", func() {
|
|
BeforeEach(func() {
|
|
handler := NewScope(false, testRequestHeader)(
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
nextRequest = r
|
|
w.WriteHeader(200)
|
|
}))
|
|
handler.ServeHTTP(rw, request)
|
|
})
|
|
|
|
It("does not add a scope to the original request", func() {
|
|
Expect(request.Context().Value(middlewareapi.RequestScopeKey)).To(BeNil())
|
|
})
|
|
|
|
It("cannot load a scope from the original request using GetRequestScope", func() {
|
|
Expect(middlewareapi.GetRequestScope(request)).To(BeNil())
|
|
})
|
|
|
|
It("adds a scope to the request for the next handler", func() {
|
|
Expect(nextRequest.Context().Value(middlewareapi.RequestScopeKey)).ToNot(BeNil())
|
|
})
|
|
|
|
It("can load a scope from the next handler's request using GetRequestScope", func() {
|
|
scope := middlewareapi.GetRequestScope(nextRequest)
|
|
Expect(scope).ToNot(BeNil())
|
|
Expect(scope.ReverseProxy).To(BeFalse())
|
|
})
|
|
})
|
|
|
|
Context("ReverseProxy is true", func() {
|
|
BeforeEach(func() {
|
|
handler := NewScope(true, testRequestHeader)(
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
nextRequest = r
|
|
w.WriteHeader(200)
|
|
}))
|
|
handler.ServeHTTP(rw, request)
|
|
})
|
|
|
|
It("return a scope where the ReverseProxy field is true", func() {
|
|
scope := middlewareapi.GetRequestScope(nextRequest)
|
|
Expect(scope).ToNot(BeNil())
|
|
Expect(scope.ReverseProxy).To(BeTrue())
|
|
})
|
|
})
|
|
|
|
Context("Request ID header is present", func() {
|
|
BeforeEach(func() {
|
|
request.Header.Add(testRequestHeader, testRequestID)
|
|
handler := NewScope(false, testRequestHeader)(
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
nextRequest = r
|
|
w.WriteHeader(200)
|
|
}))
|
|
handler.ServeHTTP(rw, request)
|
|
})
|
|
|
|
It("sets the RequestID using the request", func() {
|
|
scope := middlewareapi.GetRequestScope(nextRequest)
|
|
Expect(scope.RequestID).To(Equal(testRequestID))
|
|
})
|
|
})
|
|
|
|
Context("Request ID header is missing", func() {
|
|
BeforeEach(func() {
|
|
uuid.SetRand(mockRand{})
|
|
|
|
handler := NewScope(true, testRequestHeader)(
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
nextRequest = r
|
|
w.WriteHeader(200)
|
|
}))
|
|
handler.ServeHTTP(rw, request)
|
|
})
|
|
|
|
AfterEach(func() {
|
|
uuid.SetRand(nil)
|
|
})
|
|
|
|
It("sets the RequestID using a random UUID", func() {
|
|
scope := middlewareapi.GetRequestScope(nextRequest)
|
|
Expect(scope.RequestID).To(Equal(testRandomUUID))
|
|
})
|
|
})
|
|
})
|
|
})
|
|
|
|
type mockRand struct{}
|
|
|
|
func (mockRand) Read(p []byte) (int, error) {
|
|
for i := range p {
|
|
p[i] = byte(i % 256)
|
|
}
|
|
return len(p), nil
|
|
}
|