mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-05-27 23:08:10 +02:00
Move upstream information to request scope
This commit is contained in:
parent
18cd045631
commit
2e72d151e2
@ -11,16 +11,15 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
|
||||||
)
|
)
|
||||||
|
|
||||||
// responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status
|
// responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status
|
||||||
// code and body size
|
// code and body size
|
||||||
type responseLogger struct {
|
type responseLogger struct {
|
||||||
w http.ResponseWriter
|
w http.ResponseWriter
|
||||||
status int
|
status int
|
||||||
size int
|
size int
|
||||||
upstream string
|
|
||||||
authInfo string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Header returns the ResponseWriter's Header
|
// Header returns the ResponseWriter's Header
|
||||||
@ -36,19 +35,17 @@ func (l *responseLogger) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err erro
|
|||||||
return nil, nil, errors.New("http.Hijacker is not available on writer")
|
return nil, nil, errors.New("http.Hijacker is not available on writer")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExtractGAPMetadata extracts and removes GAP headers from the ResponseWriter's
|
// extractMetadata extracts metadata from the request/reqsponse for logging
|
||||||
// Header
|
func extractMetadata(rw http.ResponseWriter, req *http.Request) (string, string) {
|
||||||
func (l *responseLogger) ExtractGAPMetadata() {
|
scope := middleware.GetRequestScope(req)
|
||||||
upstream := l.w.Header().Get("GAP-Upstream-Address")
|
upstream := scope.Upstream
|
||||||
if upstream != "" {
|
|
||||||
l.upstream = upstream
|
authInfo := rw.Header().Get("GAP-Auth")
|
||||||
l.w.Header().Del("GAP-Upstream-Address")
|
|
||||||
}
|
|
||||||
authInfo := l.w.Header().Get("GAP-Auth")
|
|
||||||
if authInfo != "" {
|
if authInfo != "" {
|
||||||
l.authInfo = authInfo
|
rw.Header().Del("GAP-Auth")
|
||||||
l.w.Header().Del("GAP-Auth")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return authInfo, upstream
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write writes the response using the ResponseWriter
|
// Write writes the response using the ResponseWriter
|
||||||
@ -57,7 +54,6 @@ func (l *responseLogger) Write(b []byte) (int, error) {
|
|||||||
// The status will be StatusOK if WriteHeader has not been called yet
|
// The status will be StatusOK if WriteHeader has not been called yet
|
||||||
l.status = http.StatusOK
|
l.status = http.StatusOK
|
||||||
}
|
}
|
||||||
l.ExtractGAPMetadata()
|
|
||||||
size, err := l.w.Write(b)
|
size, err := l.w.Write(b)
|
||||||
l.size += size
|
l.size += size
|
||||||
return size, err
|
return size, err
|
||||||
@ -65,7 +61,6 @@ func (l *responseLogger) Write(b []byte) (int, error) {
|
|||||||
|
|
||||||
// WriteHeader writes the status code for the Response
|
// WriteHeader writes the status code for the Response
|
||||||
func (l *responseLogger) WriteHeader(s int) {
|
func (l *responseLogger) WriteHeader(s int) {
|
||||||
l.ExtractGAPMetadata()
|
|
||||||
l.w.WriteHeader(s)
|
l.w.WriteHeader(s)
|
||||||
l.status = s
|
l.status = s
|
||||||
}
|
}
|
||||||
@ -104,5 +99,7 @@ func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|||||||
url := *req.URL
|
url := *req.URL
|
||||||
responseLogger := &responseLogger{w: w}
|
responseLogger := &responseLogger{w: w}
|
||||||
h.handler.ServeHTTP(responseLogger, req)
|
h.handler.ServeHTTP(responseLogger, req)
|
||||||
logger.PrintReq(responseLogger.authInfo, responseLogger.upstream, req, url, t, responseLogger.Status(), responseLogger.Size())
|
|
||||||
|
authInfo, upstream := extractMetadata(w, req)
|
||||||
|
logger.PrintReq(authInfo, upstream, req, url, t, responseLogger.Status(), responseLogger.Size())
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,9 @@ import (
|
|||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/justinas/alice"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -102,7 +104,7 @@ func TestLoggingHandler_ServeHTTP(t *testing.T) {
|
|||||||
logger.SetOutput(buf)
|
logger.SetOutput(buf)
|
||||||
logger.SetReqTemplate(test.Format)
|
logger.SetReqTemplate(test.Format)
|
||||||
logger.SetExcludePaths(test.ExcludePaths)
|
logger.SetExcludePaths(test.ExcludePaths)
|
||||||
h := LoggingHandler(http.HandlerFunc(handler))
|
h := alice.New(middleware.NewScope(), LoggingHandler).Then(http.HandlerFunc(handler))
|
||||||
|
|
||||||
r, _ := http.NewRequest("GET", test.Path, nil)
|
r, _ := http.NewRequest("GET", test.Path, nil)
|
||||||
r.RemoteAddr = "127.0.0.1"
|
r.RemoteAddr = "127.0.0.1"
|
||||||
|
@ -21,4 +21,7 @@ type RequestScope struct {
|
|||||||
// SessionRevalidated indicates whether the session has been revalidated since
|
// SessionRevalidated indicates whether the session has been revalidated since
|
||||||
// it was loaded or not.
|
// it was loaded or not.
|
||||||
SessionRevalidated bool
|
SessionRevalidated bool
|
||||||
|
|
||||||
|
// Upstream indicates which (if any) upstream server the request was proxied to.
|
||||||
|
Upstream string
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,8 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
|
||||||
)
|
)
|
||||||
|
|
||||||
const fileScheme = "file"
|
const fileScheme = "file"
|
||||||
@ -37,6 +39,11 @@ type fileServer struct {
|
|||||||
// ServeHTTP proxies requests to the upstream provider while signing the
|
// ServeHTTP proxies requests to the upstream provider while signing the
|
||||||
// request headers
|
// request headers
|
||||||
func (u *fileServer) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
func (u *fileServer) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
rw.Header().Set("GAP-Upstream-Address", u.upstream)
|
scope := middleware.GetRequestScope(req)
|
||||||
|
|
||||||
|
// If scope is nil, this will panic.
|
||||||
|
// A scope should always be injected before this handler is called.
|
||||||
|
scope.Upstream = u.upstream
|
||||||
|
|
||||||
u.handler.ServeHTTP(rw, req)
|
u.handler.ServeHTTP(rw, req)
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,9 @@ import (
|
|||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"github.com/justinas/alice"
|
||||||
|
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
. "github.com/onsi/ginkgo/extensions/table"
|
. "github.com/onsi/ginkgo/extensions/table"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
@ -16,6 +19,7 @@ var _ = Describe("File Server Suite", func() {
|
|||||||
var dir string
|
var dir string
|
||||||
var handler http.Handler
|
var handler http.Handler
|
||||||
var id string
|
var id string
|
||||||
|
var scope *middlewareapi.RequestScope
|
||||||
|
|
||||||
const (
|
const (
|
||||||
foo = "foo"
|
foo = "foo"
|
||||||
@ -25,14 +29,24 @@ var _ = Describe("File Server Suite", func() {
|
|||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
// Generate a random id before each test to check the GAP-Upstream-Address
|
// Generate a random id before each test to check the upstream
|
||||||
// is being set correctly
|
// is being set correctly in the scope
|
||||||
idBytes := make([]byte, 16)
|
idBytes := make([]byte, 16)
|
||||||
_, err := io.ReadFull(rand.Reader, idBytes)
|
_, err := io.ReadFull(rand.Reader, idBytes)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
id = string(idBytes)
|
id = string(idBytes)
|
||||||
|
|
||||||
handler = newFileServer(id, "/files", filesDir)
|
scope = nil
|
||||||
|
// Extract the scope so that we can see that the upstream has been set
|
||||||
|
// correctly
|
||||||
|
extractScope := func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
scope = middleware.GetRequestScope(req)
|
||||||
|
next.ServeHTTP(rw, req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
handler = alice.New(middleware.NewScope(), extractScope).Then(newFileServer(id, "/files", filesDir))
|
||||||
})
|
})
|
||||||
|
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
@ -45,7 +59,7 @@ var _ = Describe("File Server Suite", func() {
|
|||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
handler.ServeHTTP(rw, req)
|
handler.ServeHTTP(rw, req)
|
||||||
|
|
||||||
Expect(rw.Header().Get("GAP-Upstream-Address")).To(Equal(id))
|
Expect(scope.Upstream).To(Equal(id))
|
||||||
Expect(rw.Code).To(Equal(expectedResponseCode))
|
Expect(rw.Code).To(Equal(expectedResponseCode))
|
||||||
Expect(rw.Body.String()).To(Equal(expectedBody))
|
Expect(rw.Body.String()).To(Equal(expectedBody))
|
||||||
},
|
},
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/mbland/hmacauth"
|
"github.com/mbland/hmacauth"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
|
||||||
"github.com/yhat/wsutil"
|
"github.com/yhat/wsutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -77,7 +78,12 @@ type httpUpstreamProxy struct {
|
|||||||
// ServeHTTP proxies requests to the upstream provider while signing the
|
// ServeHTTP proxies requests to the upstream provider while signing the
|
||||||
// request headers
|
// request headers
|
||||||
func (h *httpUpstreamProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
func (h *httpUpstreamProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
rw.Header().Set("GAP-Upstream-Address", h.upstream)
|
scope := middleware.GetRequestScope(req)
|
||||||
|
|
||||||
|
// If scope is nil, this will panic.
|
||||||
|
// A scope should always be injected before this handler is called.
|
||||||
|
scope.Upstream = h.upstream
|
||||||
|
|
||||||
if h.auth != nil {
|
if h.auth != nil {
|
||||||
req.Header.Set("GAP-Auth", rw.Header().Get("GAP-Auth"))
|
req.Header.Set("GAP-Auth", rw.Header().Get("GAP-Auth"))
|
||||||
h.auth.SignRequest(req)
|
h.auth.SignRequest(req)
|
||||||
|
@ -13,7 +13,10 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/justinas/alice"
|
||||||
|
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
. "github.com/onsi/ginkgo/extensions/table"
|
. "github.com/onsi/ginkgo/extensions/table"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
@ -35,6 +38,7 @@ var _ = Describe("HTTP Upstream Suite", func() {
|
|||||||
body []byte
|
body []byte
|
||||||
signatureData *options.SignatureData
|
signatureData *options.SignatureData
|
||||||
existingHeaders map[string]string
|
existingHeaders map[string]string
|
||||||
|
expectedUpstream string
|
||||||
expectedResponse testHTTPResponse
|
expectedResponse testHTTPResponse
|
||||||
errorHandler ProxyErrorHandler
|
errorHandler ProxyErrorHandler
|
||||||
}
|
}
|
||||||
@ -66,10 +70,21 @@ var _ = Describe("HTTP Upstream Suite", func() {
|
|||||||
u, err := url.Parse(*in.serverAddr)
|
u, err := url.Parse(*in.serverAddr)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
handler := newHTTPUpstreamProxy(upstream, u, in.signatureData, in.errorHandler)
|
var scope *middlewareapi.RequestScope
|
||||||
|
// Extract the scope so that we can see that the upstream has been set
|
||||||
|
// correctly
|
||||||
|
extractScope := func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
scope = middleware.GetRequestScope(req)
|
||||||
|
next.ServeHTTP(rw, req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := alice.New(middleware.NewScope(), extractScope).Then(newHTTPUpstreamProxy(upstream, u, in.signatureData, in.errorHandler))
|
||||||
handler.ServeHTTP(rw, req)
|
handler.ServeHTTP(rw, req)
|
||||||
|
|
||||||
Expect(rw.Code).To(Equal(in.expectedResponse.code))
|
Expect(rw.Code).To(Equal(in.expectedResponse.code))
|
||||||
|
Expect(scope.Upstream).To(Equal(in.expectedUpstream))
|
||||||
|
|
||||||
// Delete extra headers that aren't relevant to tests
|
// Delete extra headers that aren't relevant to tests
|
||||||
testSanitizeResponseHeader(rw.Header())
|
testSanitizeResponseHeader(rw.Header())
|
||||||
@ -88,16 +103,16 @@ var _ = Describe("HTTP Upstream Suite", func() {
|
|||||||
Expect(request).To(Equal(in.expectedResponse.request))
|
Expect(request).To(Equal(in.expectedResponse.request))
|
||||||
},
|
},
|
||||||
Entry("request a path on the server", &httpUpstreamTableInput{
|
Entry("request a path on the server", &httpUpstreamTableInput{
|
||||||
id: "default",
|
id: "default",
|
||||||
serverAddr: &serverAddr,
|
serverAddr: &serverAddr,
|
||||||
target: "http://example.localhost/foo",
|
target: "http://example.localhost/foo",
|
||||||
method: "GET",
|
method: "GET",
|
||||||
body: []byte{},
|
body: []byte{},
|
||||||
errorHandler: nil,
|
errorHandler: nil,
|
||||||
|
expectedUpstream: "default",
|
||||||
expectedResponse: testHTTPResponse{
|
expectedResponse: testHTTPResponse{
|
||||||
code: 200,
|
code: 200,
|
||||||
header: map[string][]string{
|
header: map[string][]string{
|
||||||
gapUpstream: {"default"},
|
|
||||||
contentType: {applicationJSON},
|
contentType: {applicationJSON},
|
||||||
},
|
},
|
||||||
request: testHTTPRequest{
|
request: testHTTPRequest{
|
||||||
@ -111,16 +126,16 @@ var _ = Describe("HTTP Upstream Suite", func() {
|
|||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
Entry("request a path with encoded slashes", &httpUpstreamTableInput{
|
Entry("request a path with encoded slashes", &httpUpstreamTableInput{
|
||||||
id: "encodedSlashes",
|
id: "encodedSlashes",
|
||||||
serverAddr: &serverAddr,
|
serverAddr: &serverAddr,
|
||||||
target: "http://example.localhost/foo%2fbar/?baz=1",
|
target: "http://example.localhost/foo%2fbar/?baz=1",
|
||||||
method: "GET",
|
method: "GET",
|
||||||
body: []byte{},
|
body: []byte{},
|
||||||
errorHandler: nil,
|
errorHandler: nil,
|
||||||
|
expectedUpstream: "encodedSlashes",
|
||||||
expectedResponse: testHTTPResponse{
|
expectedResponse: testHTTPResponse{
|
||||||
code: 200,
|
code: 200,
|
||||||
header: map[string][]string{
|
header: map[string][]string{
|
||||||
gapUpstream: {"encodedSlashes"},
|
|
||||||
contentType: {applicationJSON},
|
contentType: {applicationJSON},
|
||||||
},
|
},
|
||||||
request: testHTTPRequest{
|
request: testHTTPRequest{
|
||||||
@ -134,16 +149,16 @@ var _ = Describe("HTTP Upstream Suite", func() {
|
|||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
Entry("when the request has a body", &httpUpstreamTableInput{
|
Entry("when the request has a body", &httpUpstreamTableInput{
|
||||||
id: "requestWithBody",
|
id: "requestWithBody",
|
||||||
serverAddr: &serverAddr,
|
serverAddr: &serverAddr,
|
||||||
target: "http://example.localhost/withBody",
|
target: "http://example.localhost/withBody",
|
||||||
method: "POST",
|
method: "POST",
|
||||||
body: []byte("body"),
|
body: []byte("body"),
|
||||||
errorHandler: nil,
|
errorHandler: nil,
|
||||||
|
expectedUpstream: "requestWithBody",
|
||||||
expectedResponse: testHTTPResponse{
|
expectedResponse: testHTTPResponse{
|
||||||
code: 200,
|
code: 200,
|
||||||
header: map[string][]string{
|
header: map[string][]string{
|
||||||
gapUpstream: {"requestWithBody"},
|
|
||||||
contentType: {applicationJSON},
|
contentType: {applicationJSON},
|
||||||
},
|
},
|
||||||
request: testHTTPRequest{
|
request: testHTTPRequest{
|
||||||
@ -159,17 +174,16 @@ var _ = Describe("HTTP Upstream Suite", func() {
|
|||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
Entry("when the upstream is unavailable", &httpUpstreamTableInput{
|
Entry("when the upstream is unavailable", &httpUpstreamTableInput{
|
||||||
id: "unavailableUpstream",
|
id: "unavailableUpstream",
|
||||||
serverAddr: &invalidServer,
|
serverAddr: &invalidServer,
|
||||||
target: "http://example.localhost/unavailableUpstream",
|
target: "http://example.localhost/unavailableUpstream",
|
||||||
method: "GET",
|
method: "GET",
|
||||||
body: []byte{},
|
body: []byte{},
|
||||||
errorHandler: nil,
|
errorHandler: nil,
|
||||||
|
expectedUpstream: "unavailableUpstream",
|
||||||
expectedResponse: testHTTPResponse{
|
expectedResponse: testHTTPResponse{
|
||||||
code: 502,
|
code: 502,
|
||||||
header: map[string][]string{
|
header: map[string][]string{},
|
||||||
gapUpstream: {"unavailableUpstream"},
|
|
||||||
},
|
|
||||||
request: testHTTPRequest{},
|
request: testHTTPRequest{},
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
@ -183,11 +197,10 @@ var _ = Describe("HTTP Upstream Suite", func() {
|
|||||||
rw.WriteHeader(502)
|
rw.WriteHeader(502)
|
||||||
rw.Write([]byte("error"))
|
rw.Write([]byte("error"))
|
||||||
},
|
},
|
||||||
|
expectedUpstream: "withErrorHandler",
|
||||||
expectedResponse: testHTTPResponse{
|
expectedResponse: testHTTPResponse{
|
||||||
code: 502,
|
code: 502,
|
||||||
header: map[string][]string{
|
header: map[string][]string{},
|
||||||
gapUpstream: {"withErrorHandler"},
|
|
||||||
},
|
|
||||||
raw: "error",
|
raw: "error",
|
||||||
request: testHTTPRequest{},
|
request: testHTTPRequest{},
|
||||||
},
|
},
|
||||||
@ -202,12 +215,12 @@ var _ = Describe("HTTP Upstream Suite", func() {
|
|||||||
Hash: crypto.SHA256,
|
Hash: crypto.SHA256,
|
||||||
Key: "key",
|
Key: "key",
|
||||||
},
|
},
|
||||||
errorHandler: nil,
|
errorHandler: nil,
|
||||||
|
expectedUpstream: "withSignature",
|
||||||
expectedResponse: testHTTPResponse{
|
expectedResponse: testHTTPResponse{
|
||||||
code: 200,
|
code: 200,
|
||||||
header: map[string][]string{
|
header: map[string][]string{
|
||||||
contentType: {applicationJSON},
|
contentType: {applicationJSON},
|
||||||
gapUpstream: {"withSignature"},
|
|
||||||
},
|
},
|
||||||
request: testHTTPRequest{
|
request: testHTTPRequest{
|
||||||
Method: "GET",
|
Method: "GET",
|
||||||
@ -223,12 +236,13 @@ var _ = Describe("HTTP Upstream Suite", func() {
|
|||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
Entry("with existing headers", &httpUpstreamTableInput{
|
Entry("with existing headers", &httpUpstreamTableInput{
|
||||||
id: "existingHeaders",
|
id: "existingHeaders",
|
||||||
serverAddr: &serverAddr,
|
serverAddr: &serverAddr,
|
||||||
target: "http://example.localhost/existingHeaders",
|
target: "http://example.localhost/existingHeaders",
|
||||||
method: "GET",
|
method: "GET",
|
||||||
body: []byte{},
|
body: []byte{},
|
||||||
errorHandler: nil,
|
errorHandler: nil,
|
||||||
|
expectedUpstream: "existingHeaders",
|
||||||
existingHeaders: map[string]string{
|
existingHeaders: map[string]string{
|
||||||
"Header1": "value1",
|
"Header1": "value1",
|
||||||
"Header2": "value2",
|
"Header2": "value2",
|
||||||
@ -236,7 +250,6 @@ var _ = Describe("HTTP Upstream Suite", func() {
|
|||||||
expectedResponse: testHTTPResponse{
|
expectedResponse: testHTTPResponse{
|
||||||
code: 200,
|
code: 200,
|
||||||
header: map[string][]string{
|
header: map[string][]string{
|
||||||
gapUpstream: {"existingHeaders"},
|
|
||||||
contentType: {applicationJSON},
|
contentType: {applicationJSON},
|
||||||
},
|
},
|
||||||
request: testHTTPRequest{
|
request: testHTTPRequest{
|
||||||
@ -274,18 +287,21 @@ var _ = Describe("HTTP Upstream Suite", func() {
|
|||||||
httpUpstream, ok := handler.(*httpUpstreamProxy)
|
httpUpstream, ok := handler.(*httpUpstreamProxy)
|
||||||
Expect(ok).To(BeTrue())
|
Expect(ok).To(BeTrue())
|
||||||
|
|
||||||
|
var gotRequest *http.Request
|
||||||
// Override the handler to just run the director and not actually send the request
|
// Override the handler to just run the director and not actually send the request
|
||||||
requestInterceptor := func(h http.Handler) http.Handler {
|
requestInterceptor := func(h http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
|
return http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
|
||||||
proxy, ok := h.(*httputil.ReverseProxy)
|
proxy, ok := h.(*httputil.ReverseProxy)
|
||||||
Expect(ok).To(BeTrue())
|
Expect(ok).To(BeTrue())
|
||||||
proxy.Director(req)
|
proxy.Director(req)
|
||||||
|
|
||||||
|
gotRequest = req
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
httpUpstream.handler = requestInterceptor(httpUpstream.handler)
|
httpUpstream.handler = requestInterceptor(httpUpstream.handler)
|
||||||
|
|
||||||
httpUpstream.ServeHTTP(rw, req)
|
alice.New(middleware.NewScope()).Then(httpUpstream).ServeHTTP(rw, req)
|
||||||
Expect(req.Host).To(Equal(strings.TrimPrefix(serverAddr, "http://")))
|
Expect(gotRequest.Host).To(Equal(strings.TrimPrefix(serverAddr, "http://")))
|
||||||
})
|
})
|
||||||
|
|
||||||
type newUpstreamTableInput struct {
|
type newUpstreamTableInput struct {
|
||||||
@ -368,6 +384,7 @@ var _ = Describe("HTTP Upstream Suite", func() {
|
|||||||
|
|
||||||
Context("with a websocket proxy", func() {
|
Context("with a websocket proxy", func() {
|
||||||
var proxyServer *httptest.Server
|
var proxyServer *httptest.Server
|
||||||
|
var scope *middlewareapi.RequestScope
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
flush := 1 * time.Second
|
flush := 1 * time.Second
|
||||||
@ -382,7 +399,17 @@ var _ = Describe("HTTP Upstream Suite", func() {
|
|||||||
u, err := url.Parse(serverAddr)
|
u, err := url.Parse(serverAddr)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
handler := newHTTPUpstreamProxy(upstream, u, nil, nil)
|
scope = nil
|
||||||
|
// Extract the scope so that we can see that the upstream has been set
|
||||||
|
// correctly
|
||||||
|
extractScope := func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
scope = middleware.GetRequestScope(req)
|
||||||
|
next.ServeHTTP(rw, req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := alice.New(middleware.NewScope(), extractScope).Then(newHTTPUpstreamProxy(upstream, u, nil, nil))
|
||||||
proxyServer = httptest.NewServer(handler)
|
proxyServer = httptest.NewServer(handler)
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -414,7 +441,7 @@ var _ = Describe("HTTP Upstream Suite", func() {
|
|||||||
response, err := http.Get(fmt.Sprintf("http://%s", proxyServer.Listener.Addr().String()))
|
response, err := http.Get(fmt.Sprintf("http://%s", proxyServer.Listener.Addr().String()))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(response.StatusCode).To(Equal(200))
|
Expect(response.StatusCode).To(Equal(200))
|
||||||
Expect(response.Header.Get(gapUpstream)).To(Equal("websocketProxy"))
|
Expect(scope.Upstream).To(Equal("websocketProxy"))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -8,7 +8,10 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
|
||||||
|
"github.com/justinas/alice"
|
||||||
|
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
. "github.com/onsi/ginkgo/extensions/table"
|
. "github.com/onsi/ginkgo/extensions/table"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
@ -16,6 +19,7 @@ import (
|
|||||||
|
|
||||||
var _ = Describe("Proxy Suite", func() {
|
var _ = Describe("Proxy Suite", func() {
|
||||||
var upstreamServer http.Handler
|
var upstreamServer http.Handler
|
||||||
|
var scope *middlewareapi.RequestScope
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
sigData := &options.SignatureData{Hash: crypto.SHA256, Key: "secret"}
|
sigData := &options.SignatureData{Hash: crypto.SHA256, Key: "secret"}
|
||||||
@ -56,12 +60,25 @@ var _ = Describe("Proxy Suite", func() {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
upstreamServer, err = NewProxy(upstreams, sigData, errorHandler)
|
proxyServer, err := NewProxy(upstreams, sigData, errorHandler)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
scope = nil
|
||||||
|
// Extract the scope so that we can see that the upstream has been set
|
||||||
|
// correctly
|
||||||
|
extractScope := func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
scope = middleware.GetRequestScope(req)
|
||||||
|
next.ServeHTTP(rw, req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamServer = alice.New(middleware.NewScope(), extractScope).Then(proxyServer)
|
||||||
})
|
})
|
||||||
|
|
||||||
type proxyTableInput struct {
|
type proxyTableInput struct {
|
||||||
target string
|
target string
|
||||||
|
upstream string
|
||||||
response testHTTPResponse
|
response testHTTPResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -75,6 +92,7 @@ var _ = Describe("Proxy Suite", func() {
|
|||||||
upstreamServer.ServeHTTP(rw, req)
|
upstreamServer.ServeHTTP(rw, req)
|
||||||
|
|
||||||
Expect(rw.Code).To(Equal(in.response.code))
|
Expect(rw.Code).To(Equal(in.response.code))
|
||||||
|
Expect(scope.Upstream).To(Equal(in.upstream))
|
||||||
|
|
||||||
// Delete extra headers that aren't relevant to tests
|
// Delete extra headers that aren't relevant to tests
|
||||||
testSanitizeResponseHeader(rw.Header())
|
testSanitizeResponseHeader(rw.Header())
|
||||||
@ -94,11 +112,11 @@ var _ = Describe("Proxy Suite", func() {
|
|||||||
Expect(request).To(Equal(in.response.request))
|
Expect(request).To(Equal(in.response.request))
|
||||||
},
|
},
|
||||||
Entry("with a request to the HTTP service", &proxyTableInput{
|
Entry("with a request to the HTTP service", &proxyTableInput{
|
||||||
target: "http://example.localhost/http/1234",
|
target: "http://example.localhost/http/1234",
|
||||||
|
upstream: "http-backend",
|
||||||
response: testHTTPResponse{
|
response: testHTTPResponse{
|
||||||
code: 200,
|
code: 200,
|
||||||
header: map[string][]string{
|
header: map[string][]string{
|
||||||
gapUpstream: {"http-backend"},
|
|
||||||
contentType: {applicationJSON},
|
contentType: {applicationJSON},
|
||||||
},
|
},
|
||||||
request: testHTTPRequest{
|
request: testHTTPRequest{
|
||||||
@ -115,33 +133,31 @@ var _ = Describe("Proxy Suite", func() {
|
|||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
Entry("with a request to the File backend", &proxyTableInput{
|
Entry("with a request to the File backend", &proxyTableInput{
|
||||||
target: "http://example.localhost/files/foo",
|
target: "http://example.localhost/files/foo",
|
||||||
|
upstream: "file-backend",
|
||||||
response: testHTTPResponse{
|
response: testHTTPResponse{
|
||||||
code: 200,
|
code: 200,
|
||||||
header: map[string][]string{
|
header: map[string][]string{
|
||||||
contentType: {textPlainUTF8},
|
contentType: {textPlainUTF8},
|
||||||
gapUpstream: {"file-backend"},
|
|
||||||
},
|
},
|
||||||
raw: "foo",
|
raw: "foo",
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
Entry("with a request to the Static backend", &proxyTableInput{
|
Entry("with a request to the Static backend", &proxyTableInput{
|
||||||
target: "http://example.localhost/static/bar",
|
target: "http://example.localhost/static/bar",
|
||||||
|
upstream: "static-backend",
|
||||||
response: testHTTPResponse{
|
response: testHTTPResponse{
|
||||||
code: 200,
|
code: 200,
|
||||||
header: map[string][]string{
|
header: map[string][]string{},
|
||||||
gapUpstream: {"static-backend"},
|
raw: "Authenticated",
|
||||||
},
|
|
||||||
raw: "Authenticated",
|
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
Entry("with a request to the bad HTTP backend", &proxyTableInput{
|
Entry("with a request to the bad HTTP backend", &proxyTableInput{
|
||||||
target: "http://example.localhost/bad-http/bad",
|
target: "http://example.localhost/bad-http/bad",
|
||||||
|
upstream: "bad-http-backend",
|
||||||
response: testHTTPResponse{
|
response: testHTTPResponse{
|
||||||
code: 502,
|
code: 502,
|
||||||
header: map[string][]string{
|
header: map[string][]string{},
|
||||||
gapUpstream: {"bad-http-backend"},
|
|
||||||
},
|
|
||||||
// This tests the error handler
|
// This tests the error handler
|
||||||
raw: "Bad Gateway\nError proxying to upstream server\nprefix",
|
raw: "Bad Gateway\nError proxying to upstream server\nprefix",
|
||||||
},
|
},
|
||||||
@ -158,13 +174,12 @@ var _ = Describe("Proxy Suite", func() {
|
|||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
Entry("with a request to the to backend registered to a single path", &proxyTableInput{
|
Entry("with a request to the to backend registered to a single path", &proxyTableInput{
|
||||||
target: "http://example.localhost/single-path",
|
target: "http://example.localhost/single-path",
|
||||||
|
upstream: "single-path-backend",
|
||||||
response: testHTTPResponse{
|
response: testHTTPResponse{
|
||||||
code: 200,
|
code: 200,
|
||||||
header: map[string][]string{
|
header: map[string][]string{},
|
||||||
gapUpstream: {"single-path-backend"},
|
raw: "Authenticated",
|
||||||
},
|
|
||||||
raw: "Authenticated",
|
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
Entry("with a request to the to a subpath of a backend registered to a single path", &proxyTableInput{
|
Entry("with a request to the to a subpath of a backend registered to a single path", &proxyTableInput{
|
||||||
|
@ -3,6 +3,8 @@ package upstream
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultStaticResponseCode = 200
|
const defaultStaticResponseCode = 200
|
||||||
@ -24,7 +26,12 @@ type staticResponseHandler struct {
|
|||||||
|
|
||||||
// ServeHTTP serves a static response.
|
// ServeHTTP serves a static response.
|
||||||
func (s *staticResponseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
func (s *staticResponseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
rw.Header().Set("GAP-Upstream-Address", s.upstream)
|
scope := middleware.GetRequestScope(req)
|
||||||
|
|
||||||
|
// If scope is nil, this will panic.
|
||||||
|
// A scope should always be injected before this handler is called.
|
||||||
|
scope.Upstream = s.upstream
|
||||||
|
|
||||||
rw.WriteHeader(s.code)
|
rw.WriteHeader(s.code)
|
||||||
fmt.Fprintf(rw, "Authenticated")
|
fmt.Fprintf(rw, "Authenticated")
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,9 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
|
||||||
|
"github.com/justinas/alice"
|
||||||
|
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
. "github.com/onsi/ginkgo/extensions/table"
|
. "github.com/onsi/ginkgo/extensions/table"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
@ -16,8 +19,8 @@ var _ = Describe("Static Response Suite", func() {
|
|||||||
var id string
|
var id string
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
// Generate a random id before each test to check the GAP-Upstream-Address
|
// Generate a random id before each test to check the upstream
|
||||||
// is being set correctly
|
// is being set correctly in the scope
|
||||||
idBytes := make([]byte, 16)
|
idBytes := make([]byte, 16)
|
||||||
_, err := io.ReadFull(rand.Reader, idBytes)
|
_, err := io.ReadFull(rand.Reader, idBytes)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
@ -37,13 +40,24 @@ var _ = Describe("Static Response Suite", func() {
|
|||||||
if in.staticCode != 0 {
|
if in.staticCode != 0 {
|
||||||
code = &in.staticCode
|
code = &in.staticCode
|
||||||
}
|
}
|
||||||
handler := newStaticResponseHandler(id, code)
|
|
||||||
|
var scope *middlewareapi.RequestScope
|
||||||
|
// Extract the scope so that we can see that the upstream has been set
|
||||||
|
// correctly
|
||||||
|
extractScope := func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
scope = middleware.GetRequestScope(req)
|
||||||
|
next.ServeHTTP(rw, req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := alice.New(middleware.NewScope(), extractScope).Then(newStaticResponseHandler(id, code))
|
||||||
|
|
||||||
req := httptest.NewRequest("", in.requestPath, nil)
|
req := httptest.NewRequest("", in.requestPath, nil)
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
handler.ServeHTTP(rw, req)
|
handler.ServeHTTP(rw, req)
|
||||||
|
|
||||||
Expect(rw.Header().Get("GAP-Upstream-Address")).To(Equal(id))
|
Expect(scope.Upstream).To(Equal(id))
|
||||||
Expect(rw.Code).To(Equal(in.expectedCode))
|
Expect(rw.Code).To(Equal(in.expectedCode))
|
||||||
Expect(rw.Body.String()).To(Equal(in.expectedBody))
|
Expect(rw.Body.String()).To(Equal(in.expectedBody))
|
||||||
},
|
},
|
||||||
|
@ -58,7 +58,6 @@ const (
|
|||||||
acceptEncoding = "Accept-Encoding"
|
acceptEncoding = "Accept-Encoding"
|
||||||
applicationJSON = "application/json"
|
applicationJSON = "application/json"
|
||||||
textPlainUTF8 = "text/plain; charset=utf-8"
|
textPlainUTF8 = "text/plain; charset=utf-8"
|
||||||
gapUpstream = "Gap-Upstream-Address"
|
|
||||||
gapAuth = "Gap-Auth"
|
gapAuth = "Gap-Auth"
|
||||||
gapSignature = "Gap-Signature"
|
gapSignature = "Gap-Signature"
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user