1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-05-27 23:08:10 +02:00

Move upstream information to request scope

This commit is contained in:
Joel Speed 2020-10-04 16:23:38 +01:00
parent 18cd045631
commit 2e72d151e2
No known key found for this signature in database
GPG Key ID: 6E80578D6751DEFB
11 changed files with 194 additions and 103 deletions

View File

@ -11,16 +11,15 @@ import (
"time" "time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
) )
// responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status // responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status
// code and body size // code and body size
type responseLogger struct { type responseLogger struct {
w http.ResponseWriter w http.ResponseWriter
status int status int
size int size int
upstream string
authInfo string
} }
// Header returns the ResponseWriter's Header // Header returns the ResponseWriter's Header
@ -36,19 +35,17 @@ func (l *responseLogger) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err erro
return nil, nil, errors.New("http.Hijacker is not available on writer") return nil, nil, errors.New("http.Hijacker is not available on writer")
} }
// ExtractGAPMetadata extracts and removes GAP headers from the ResponseWriter's // extractMetadata extracts metadata from the request/reqsponse for logging
// Header func extractMetadata(rw http.ResponseWriter, req *http.Request) (string, string) {
func (l *responseLogger) ExtractGAPMetadata() { scope := middleware.GetRequestScope(req)
upstream := l.w.Header().Get("GAP-Upstream-Address") upstream := scope.Upstream
if upstream != "" {
l.upstream = upstream authInfo := rw.Header().Get("GAP-Auth")
l.w.Header().Del("GAP-Upstream-Address")
}
authInfo := l.w.Header().Get("GAP-Auth")
if authInfo != "" { if authInfo != "" {
l.authInfo = authInfo rw.Header().Del("GAP-Auth")
l.w.Header().Del("GAP-Auth")
} }
return authInfo, upstream
} }
// Write writes the response using the ResponseWriter // Write writes the response using the ResponseWriter
@ -57,7 +54,6 @@ func (l *responseLogger) Write(b []byte) (int, error) {
// The status will be StatusOK if WriteHeader has not been called yet // The status will be StatusOK if WriteHeader has not been called yet
l.status = http.StatusOK l.status = http.StatusOK
} }
l.ExtractGAPMetadata()
size, err := l.w.Write(b) size, err := l.w.Write(b)
l.size += size l.size += size
return size, err return size, err
@ -65,7 +61,6 @@ func (l *responseLogger) Write(b []byte) (int, error) {
// WriteHeader writes the status code for the Response // WriteHeader writes the status code for the Response
func (l *responseLogger) WriteHeader(s int) { func (l *responseLogger) WriteHeader(s int) {
l.ExtractGAPMetadata()
l.w.WriteHeader(s) l.w.WriteHeader(s)
l.status = s l.status = s
} }
@ -104,5 +99,7 @@ func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
url := *req.URL url := *req.URL
responseLogger := &responseLogger{w: w} responseLogger := &responseLogger{w: w}
h.handler.ServeHTTP(responseLogger, req) h.handler.ServeHTTP(responseLogger, req)
logger.PrintReq(responseLogger.authInfo, responseLogger.upstream, req, url, t, responseLogger.Status(), responseLogger.Size())
authInfo, upstream := extractMetadata(w, req)
logger.PrintReq(authInfo, upstream, req, url, t, responseLogger.Status(), responseLogger.Size())
} }

View File

@ -6,7 +6,9 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/justinas/alice"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -102,7 +104,7 @@ func TestLoggingHandler_ServeHTTP(t *testing.T) {
logger.SetOutput(buf) logger.SetOutput(buf)
logger.SetReqTemplate(test.Format) logger.SetReqTemplate(test.Format)
logger.SetExcludePaths(test.ExcludePaths) logger.SetExcludePaths(test.ExcludePaths)
h := LoggingHandler(http.HandlerFunc(handler)) h := alice.New(middleware.NewScope(), LoggingHandler).Then(http.HandlerFunc(handler))
r, _ := http.NewRequest("GET", test.Path, nil) r, _ := http.NewRequest("GET", test.Path, nil)
r.RemoteAddr = "127.0.0.1" r.RemoteAddr = "127.0.0.1"

View File

@ -21,4 +21,7 @@ type RequestScope struct {
// SessionRevalidated indicates whether the session has been revalidated since // SessionRevalidated indicates whether the session has been revalidated since
// it was loaded or not. // it was loaded or not.
SessionRevalidated bool SessionRevalidated bool
// Upstream indicates which (if any) upstream server the request was proxied to.
Upstream string
} }

View File

@ -4,6 +4,8 @@ import (
"net/http" "net/http"
"runtime" "runtime"
"strings" "strings"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
) )
const fileScheme = "file" const fileScheme = "file"
@ -37,6 +39,11 @@ type fileServer struct {
// ServeHTTP proxies requests to the upstream provider while signing the // ServeHTTP proxies requests to the upstream provider while signing the
// request headers // request headers
func (u *fileServer) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (u *fileServer) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("GAP-Upstream-Address", u.upstream) scope := middleware.GetRequestScope(req)
// If scope is nil, this will panic.
// A scope should always be injected before this handler is called.
scope.Upstream = u.upstream
u.handler.ServeHTTP(rw, req) u.handler.ServeHTTP(rw, req)
} }

View File

@ -7,6 +7,9 @@ import (
"net/http/httptest" "net/http/httptest"
"os" "os"
"github.com/justinas/alice"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -16,6 +19,7 @@ var _ = Describe("File Server Suite", func() {
var dir string var dir string
var handler http.Handler var handler http.Handler
var id string var id string
var scope *middlewareapi.RequestScope
const ( const (
foo = "foo" foo = "foo"
@ -25,14 +29,24 @@ var _ = Describe("File Server Suite", func() {
) )
BeforeEach(func() { BeforeEach(func() {
// Generate a random id before each test to check the GAP-Upstream-Address // Generate a random id before each test to check the upstream
// is being set correctly // is being set correctly in the scope
idBytes := make([]byte, 16) idBytes := make([]byte, 16)
_, err := io.ReadFull(rand.Reader, idBytes) _, err := io.ReadFull(rand.Reader, idBytes)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
id = string(idBytes) id = string(idBytes)
handler = newFileServer(id, "/files", filesDir) scope = nil
// Extract the scope so that we can see that the upstream has been set
// correctly
extractScope := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
scope = middleware.GetRequestScope(req)
next.ServeHTTP(rw, req)
})
}
handler = alice.New(middleware.NewScope(), extractScope).Then(newFileServer(id, "/files", filesDir))
}) })
AfterEach(func() { AfterEach(func() {
@ -45,7 +59,7 @@ var _ = Describe("File Server Suite", func() {
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
handler.ServeHTTP(rw, req) handler.ServeHTTP(rw, req)
Expect(rw.Header().Get("GAP-Upstream-Address")).To(Equal(id)) Expect(scope.Upstream).To(Equal(id))
Expect(rw.Code).To(Equal(expectedResponseCode)) Expect(rw.Code).To(Equal(expectedResponseCode))
Expect(rw.Body.String()).To(Equal(expectedBody)) Expect(rw.Body.String()).To(Equal(expectedBody))
}, },

View File

@ -10,6 +10,7 @@ import (
"github.com/mbland/hmacauth" "github.com/mbland/hmacauth"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
"github.com/yhat/wsutil" "github.com/yhat/wsutil"
) )
@ -77,7 +78,12 @@ type httpUpstreamProxy struct {
// ServeHTTP proxies requests to the upstream provider while signing the // ServeHTTP proxies requests to the upstream provider while signing the
// request headers // request headers
func (h *httpUpstreamProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (h *httpUpstreamProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("GAP-Upstream-Address", h.upstream) scope := middleware.GetRequestScope(req)
// If scope is nil, this will panic.
// A scope should always be injected before this handler is called.
scope.Upstream = h.upstream
if h.auth != nil { if h.auth != nil {
req.Header.Set("GAP-Auth", rw.Header().Get("GAP-Auth")) req.Header.Set("GAP-Auth", rw.Header().Get("GAP-Auth"))
h.auth.SignRequest(req) h.auth.SignRequest(req)

View File

@ -13,7 +13,10 @@ import (
"strings" "strings"
"time" "time"
"github.com/justinas/alice"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -35,6 +38,7 @@ var _ = Describe("HTTP Upstream Suite", func() {
body []byte body []byte
signatureData *options.SignatureData signatureData *options.SignatureData
existingHeaders map[string]string existingHeaders map[string]string
expectedUpstream string
expectedResponse testHTTPResponse expectedResponse testHTTPResponse
errorHandler ProxyErrorHandler errorHandler ProxyErrorHandler
} }
@ -66,10 +70,21 @@ var _ = Describe("HTTP Upstream Suite", func() {
u, err := url.Parse(*in.serverAddr) u, err := url.Parse(*in.serverAddr)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
handler := newHTTPUpstreamProxy(upstream, u, in.signatureData, in.errorHandler) var scope *middlewareapi.RequestScope
// Extract the scope so that we can see that the upstream has been set
// correctly
extractScope := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
scope = middleware.GetRequestScope(req)
next.ServeHTTP(rw, req)
})
}
handler := alice.New(middleware.NewScope(), extractScope).Then(newHTTPUpstreamProxy(upstream, u, in.signatureData, in.errorHandler))
handler.ServeHTTP(rw, req) handler.ServeHTTP(rw, req)
Expect(rw.Code).To(Equal(in.expectedResponse.code)) Expect(rw.Code).To(Equal(in.expectedResponse.code))
Expect(scope.Upstream).To(Equal(in.expectedUpstream))
// Delete extra headers that aren't relevant to tests // Delete extra headers that aren't relevant to tests
testSanitizeResponseHeader(rw.Header()) testSanitizeResponseHeader(rw.Header())
@ -88,16 +103,16 @@ var _ = Describe("HTTP Upstream Suite", func() {
Expect(request).To(Equal(in.expectedResponse.request)) Expect(request).To(Equal(in.expectedResponse.request))
}, },
Entry("request a path on the server", &httpUpstreamTableInput{ Entry("request a path on the server", &httpUpstreamTableInput{
id: "default", id: "default",
serverAddr: &serverAddr, serverAddr: &serverAddr,
target: "http://example.localhost/foo", target: "http://example.localhost/foo",
method: "GET", method: "GET",
body: []byte{}, body: []byte{},
errorHandler: nil, errorHandler: nil,
expectedUpstream: "default",
expectedResponse: testHTTPResponse{ expectedResponse: testHTTPResponse{
code: 200, code: 200,
header: map[string][]string{ header: map[string][]string{
gapUpstream: {"default"},
contentType: {applicationJSON}, contentType: {applicationJSON},
}, },
request: testHTTPRequest{ request: testHTTPRequest{
@ -111,16 +126,16 @@ var _ = Describe("HTTP Upstream Suite", func() {
}, },
}), }),
Entry("request a path with encoded slashes", &httpUpstreamTableInput{ Entry("request a path with encoded slashes", &httpUpstreamTableInput{
id: "encodedSlashes", id: "encodedSlashes",
serverAddr: &serverAddr, serverAddr: &serverAddr,
target: "http://example.localhost/foo%2fbar/?baz=1", target: "http://example.localhost/foo%2fbar/?baz=1",
method: "GET", method: "GET",
body: []byte{}, body: []byte{},
errorHandler: nil, errorHandler: nil,
expectedUpstream: "encodedSlashes",
expectedResponse: testHTTPResponse{ expectedResponse: testHTTPResponse{
code: 200, code: 200,
header: map[string][]string{ header: map[string][]string{
gapUpstream: {"encodedSlashes"},
contentType: {applicationJSON}, contentType: {applicationJSON},
}, },
request: testHTTPRequest{ request: testHTTPRequest{
@ -134,16 +149,16 @@ var _ = Describe("HTTP Upstream Suite", func() {
}, },
}), }),
Entry("when the request has a body", &httpUpstreamTableInput{ Entry("when the request has a body", &httpUpstreamTableInput{
id: "requestWithBody", id: "requestWithBody",
serverAddr: &serverAddr, serverAddr: &serverAddr,
target: "http://example.localhost/withBody", target: "http://example.localhost/withBody",
method: "POST", method: "POST",
body: []byte("body"), body: []byte("body"),
errorHandler: nil, errorHandler: nil,
expectedUpstream: "requestWithBody",
expectedResponse: testHTTPResponse{ expectedResponse: testHTTPResponse{
code: 200, code: 200,
header: map[string][]string{ header: map[string][]string{
gapUpstream: {"requestWithBody"},
contentType: {applicationJSON}, contentType: {applicationJSON},
}, },
request: testHTTPRequest{ request: testHTTPRequest{
@ -159,17 +174,16 @@ var _ = Describe("HTTP Upstream Suite", func() {
}, },
}), }),
Entry("when the upstream is unavailable", &httpUpstreamTableInput{ Entry("when the upstream is unavailable", &httpUpstreamTableInput{
id: "unavailableUpstream", id: "unavailableUpstream",
serverAddr: &invalidServer, serverAddr: &invalidServer,
target: "http://example.localhost/unavailableUpstream", target: "http://example.localhost/unavailableUpstream",
method: "GET", method: "GET",
body: []byte{}, body: []byte{},
errorHandler: nil, errorHandler: nil,
expectedUpstream: "unavailableUpstream",
expectedResponse: testHTTPResponse{ expectedResponse: testHTTPResponse{
code: 502, code: 502,
header: map[string][]string{ header: map[string][]string{},
gapUpstream: {"unavailableUpstream"},
},
request: testHTTPRequest{}, request: testHTTPRequest{},
}, },
}), }),
@ -183,11 +197,10 @@ var _ = Describe("HTTP Upstream Suite", func() {
rw.WriteHeader(502) rw.WriteHeader(502)
rw.Write([]byte("error")) rw.Write([]byte("error"))
}, },
expectedUpstream: "withErrorHandler",
expectedResponse: testHTTPResponse{ expectedResponse: testHTTPResponse{
code: 502, code: 502,
header: map[string][]string{ header: map[string][]string{},
gapUpstream: {"withErrorHandler"},
},
raw: "error", raw: "error",
request: testHTTPRequest{}, request: testHTTPRequest{},
}, },
@ -202,12 +215,12 @@ var _ = Describe("HTTP Upstream Suite", func() {
Hash: crypto.SHA256, Hash: crypto.SHA256,
Key: "key", Key: "key",
}, },
errorHandler: nil, errorHandler: nil,
expectedUpstream: "withSignature",
expectedResponse: testHTTPResponse{ expectedResponse: testHTTPResponse{
code: 200, code: 200,
header: map[string][]string{ header: map[string][]string{
contentType: {applicationJSON}, contentType: {applicationJSON},
gapUpstream: {"withSignature"},
}, },
request: testHTTPRequest{ request: testHTTPRequest{
Method: "GET", Method: "GET",
@ -223,12 +236,13 @@ var _ = Describe("HTTP Upstream Suite", func() {
}, },
}), }),
Entry("with existing headers", &httpUpstreamTableInput{ Entry("with existing headers", &httpUpstreamTableInput{
id: "existingHeaders", id: "existingHeaders",
serverAddr: &serverAddr, serverAddr: &serverAddr,
target: "http://example.localhost/existingHeaders", target: "http://example.localhost/existingHeaders",
method: "GET", method: "GET",
body: []byte{}, body: []byte{},
errorHandler: nil, errorHandler: nil,
expectedUpstream: "existingHeaders",
existingHeaders: map[string]string{ existingHeaders: map[string]string{
"Header1": "value1", "Header1": "value1",
"Header2": "value2", "Header2": "value2",
@ -236,7 +250,6 @@ var _ = Describe("HTTP Upstream Suite", func() {
expectedResponse: testHTTPResponse{ expectedResponse: testHTTPResponse{
code: 200, code: 200,
header: map[string][]string{ header: map[string][]string{
gapUpstream: {"existingHeaders"},
contentType: {applicationJSON}, contentType: {applicationJSON},
}, },
request: testHTTPRequest{ request: testHTTPRequest{
@ -274,18 +287,21 @@ var _ = Describe("HTTP Upstream Suite", func() {
httpUpstream, ok := handler.(*httpUpstreamProxy) httpUpstream, ok := handler.(*httpUpstreamProxy)
Expect(ok).To(BeTrue()) Expect(ok).To(BeTrue())
var gotRequest *http.Request
// Override the handler to just run the director and not actually send the request // Override the handler to just run the director and not actually send the request
requestInterceptor := func(h http.Handler) http.Handler { requestInterceptor := func(h http.Handler) http.Handler {
return http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) { return http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
proxy, ok := h.(*httputil.ReverseProxy) proxy, ok := h.(*httputil.ReverseProxy)
Expect(ok).To(BeTrue()) Expect(ok).To(BeTrue())
proxy.Director(req) proxy.Director(req)
gotRequest = req
}) })
} }
httpUpstream.handler = requestInterceptor(httpUpstream.handler) httpUpstream.handler = requestInterceptor(httpUpstream.handler)
httpUpstream.ServeHTTP(rw, req) alice.New(middleware.NewScope()).Then(httpUpstream).ServeHTTP(rw, req)
Expect(req.Host).To(Equal(strings.TrimPrefix(serverAddr, "http://"))) Expect(gotRequest.Host).To(Equal(strings.TrimPrefix(serverAddr, "http://")))
}) })
type newUpstreamTableInput struct { type newUpstreamTableInput struct {
@ -368,6 +384,7 @@ var _ = Describe("HTTP Upstream Suite", func() {
Context("with a websocket proxy", func() { Context("with a websocket proxy", func() {
var proxyServer *httptest.Server var proxyServer *httptest.Server
var scope *middlewareapi.RequestScope
BeforeEach(func() { BeforeEach(func() {
flush := 1 * time.Second flush := 1 * time.Second
@ -382,7 +399,17 @@ var _ = Describe("HTTP Upstream Suite", func() {
u, err := url.Parse(serverAddr) u, err := url.Parse(serverAddr)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
handler := newHTTPUpstreamProxy(upstream, u, nil, nil) scope = nil
// Extract the scope so that we can see that the upstream has been set
// correctly
extractScope := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
scope = middleware.GetRequestScope(req)
next.ServeHTTP(rw, req)
})
}
handler := alice.New(middleware.NewScope(), extractScope).Then(newHTTPUpstreamProxy(upstream, u, nil, nil))
proxyServer = httptest.NewServer(handler) proxyServer = httptest.NewServer(handler)
}) })
@ -414,7 +441,7 @@ var _ = Describe("HTTP Upstream Suite", func() {
response, err := http.Get(fmt.Sprintf("http://%s", proxyServer.Listener.Addr().String())) response, err := http.Get(fmt.Sprintf("http://%s", proxyServer.Listener.Addr().String()))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(response.StatusCode).To(Equal(200)) Expect(response.StatusCode).To(Equal(200))
Expect(response.Header.Get(gapUpstream)).To(Equal("websocketProxy")) Expect(scope.Upstream).To(Equal("websocketProxy"))
}) })
}) })
}) })

View File

@ -8,7 +8,10 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"github.com/justinas/alice"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -16,6 +19,7 @@ import (
var _ = Describe("Proxy Suite", func() { var _ = Describe("Proxy Suite", func() {
var upstreamServer http.Handler var upstreamServer http.Handler
var scope *middlewareapi.RequestScope
BeforeEach(func() { BeforeEach(func() {
sigData := &options.SignatureData{Hash: crypto.SHA256, Key: "secret"} sigData := &options.SignatureData{Hash: crypto.SHA256, Key: "secret"}
@ -56,12 +60,25 @@ var _ = Describe("Proxy Suite", func() {
}, },
} }
upstreamServer, err = NewProxy(upstreams, sigData, errorHandler) proxyServer, err := NewProxy(upstreams, sigData, errorHandler)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
scope = nil
// Extract the scope so that we can see that the upstream has been set
// correctly
extractScope := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
scope = middleware.GetRequestScope(req)
next.ServeHTTP(rw, req)
})
}
upstreamServer = alice.New(middleware.NewScope(), extractScope).Then(proxyServer)
}) })
type proxyTableInput struct { type proxyTableInput struct {
target string target string
upstream string
response testHTTPResponse response testHTTPResponse
} }
@ -75,6 +92,7 @@ var _ = Describe("Proxy Suite", func() {
upstreamServer.ServeHTTP(rw, req) upstreamServer.ServeHTTP(rw, req)
Expect(rw.Code).To(Equal(in.response.code)) Expect(rw.Code).To(Equal(in.response.code))
Expect(scope.Upstream).To(Equal(in.upstream))
// Delete extra headers that aren't relevant to tests // Delete extra headers that aren't relevant to tests
testSanitizeResponseHeader(rw.Header()) testSanitizeResponseHeader(rw.Header())
@ -94,11 +112,11 @@ var _ = Describe("Proxy Suite", func() {
Expect(request).To(Equal(in.response.request)) Expect(request).To(Equal(in.response.request))
}, },
Entry("with a request to the HTTP service", &proxyTableInput{ Entry("with a request to the HTTP service", &proxyTableInput{
target: "http://example.localhost/http/1234", target: "http://example.localhost/http/1234",
upstream: "http-backend",
response: testHTTPResponse{ response: testHTTPResponse{
code: 200, code: 200,
header: map[string][]string{ header: map[string][]string{
gapUpstream: {"http-backend"},
contentType: {applicationJSON}, contentType: {applicationJSON},
}, },
request: testHTTPRequest{ request: testHTTPRequest{
@ -115,33 +133,31 @@ var _ = Describe("Proxy Suite", func() {
}, },
}), }),
Entry("with a request to the File backend", &proxyTableInput{ Entry("with a request to the File backend", &proxyTableInput{
target: "http://example.localhost/files/foo", target: "http://example.localhost/files/foo",
upstream: "file-backend",
response: testHTTPResponse{ response: testHTTPResponse{
code: 200, code: 200,
header: map[string][]string{ header: map[string][]string{
contentType: {textPlainUTF8}, contentType: {textPlainUTF8},
gapUpstream: {"file-backend"},
}, },
raw: "foo", raw: "foo",
}, },
}), }),
Entry("with a request to the Static backend", &proxyTableInput{ Entry("with a request to the Static backend", &proxyTableInput{
target: "http://example.localhost/static/bar", target: "http://example.localhost/static/bar",
upstream: "static-backend",
response: testHTTPResponse{ response: testHTTPResponse{
code: 200, code: 200,
header: map[string][]string{ header: map[string][]string{},
gapUpstream: {"static-backend"}, raw: "Authenticated",
},
raw: "Authenticated",
}, },
}), }),
Entry("with a request to the bad HTTP backend", &proxyTableInput{ Entry("with a request to the bad HTTP backend", &proxyTableInput{
target: "http://example.localhost/bad-http/bad", target: "http://example.localhost/bad-http/bad",
upstream: "bad-http-backend",
response: testHTTPResponse{ response: testHTTPResponse{
code: 502, code: 502,
header: map[string][]string{ header: map[string][]string{},
gapUpstream: {"bad-http-backend"},
},
// This tests the error handler // This tests the error handler
raw: "Bad Gateway\nError proxying to upstream server\nprefix", raw: "Bad Gateway\nError proxying to upstream server\nprefix",
}, },
@ -158,13 +174,12 @@ var _ = Describe("Proxy Suite", func() {
}, },
}), }),
Entry("with a request to the to backend registered to a single path", &proxyTableInput{ Entry("with a request to the to backend registered to a single path", &proxyTableInput{
target: "http://example.localhost/single-path", target: "http://example.localhost/single-path",
upstream: "single-path-backend",
response: testHTTPResponse{ response: testHTTPResponse{
code: 200, code: 200,
header: map[string][]string{ header: map[string][]string{},
gapUpstream: {"single-path-backend"}, raw: "Authenticated",
},
raw: "Authenticated",
}, },
}), }),
Entry("with a request to the to a subpath of a backend registered to a single path", &proxyTableInput{ Entry("with a request to the to a subpath of a backend registered to a single path", &proxyTableInput{

View File

@ -3,6 +3,8 @@ package upstream
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
) )
const defaultStaticResponseCode = 200 const defaultStaticResponseCode = 200
@ -24,7 +26,12 @@ type staticResponseHandler struct {
// ServeHTTP serves a static response. // ServeHTTP serves a static response.
func (s *staticResponseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (s *staticResponseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("GAP-Upstream-Address", s.upstream) scope := middleware.GetRequestScope(req)
// If scope is nil, this will panic.
// A scope should always be injected before this handler is called.
scope.Upstream = s.upstream
rw.WriteHeader(s.code) rw.WriteHeader(s.code)
fmt.Fprintf(rw, "Authenticated") fmt.Fprintf(rw, "Authenticated")
} }

View File

@ -6,6 +6,9 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"github.com/justinas/alice"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -16,8 +19,8 @@ var _ = Describe("Static Response Suite", func() {
var id string var id string
BeforeEach(func() { BeforeEach(func() {
// Generate a random id before each test to check the GAP-Upstream-Address // Generate a random id before each test to check the upstream
// is being set correctly // is being set correctly in the scope
idBytes := make([]byte, 16) idBytes := make([]byte, 16)
_, err := io.ReadFull(rand.Reader, idBytes) _, err := io.ReadFull(rand.Reader, idBytes)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -37,13 +40,24 @@ var _ = Describe("Static Response Suite", func() {
if in.staticCode != 0 { if in.staticCode != 0 {
code = &in.staticCode code = &in.staticCode
} }
handler := newStaticResponseHandler(id, code)
var scope *middlewareapi.RequestScope
// Extract the scope so that we can see that the upstream has been set
// correctly
extractScope := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
scope = middleware.GetRequestScope(req)
next.ServeHTTP(rw, req)
})
}
handler := alice.New(middleware.NewScope(), extractScope).Then(newStaticResponseHandler(id, code))
req := httptest.NewRequest("", in.requestPath, nil) req := httptest.NewRequest("", in.requestPath, nil)
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
handler.ServeHTTP(rw, req) handler.ServeHTTP(rw, req)
Expect(rw.Header().Get("GAP-Upstream-Address")).To(Equal(id)) Expect(scope.Upstream).To(Equal(id))
Expect(rw.Code).To(Equal(in.expectedCode)) Expect(rw.Code).To(Equal(in.expectedCode))
Expect(rw.Body.String()).To(Equal(in.expectedBody)) Expect(rw.Body.String()).To(Equal(in.expectedBody))
}, },

View File

@ -58,7 +58,6 @@ const (
acceptEncoding = "Accept-Encoding" acceptEncoding = "Accept-Encoding"
applicationJSON = "application/json" applicationJSON = "application/json"
textPlainUTF8 = "text/plain; charset=utf-8" textPlainUTF8 = "text/plain; charset=utf-8"
gapUpstream = "Gap-Upstream-Address"
gapAuth = "Gap-Auth" gapAuth = "Gap-Auth"
gapSignature = "Gap-Signature" gapSignature = "Gap-Signature"
) )