mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-01-10 04:18:14 +02:00
481 lines
14 KiB
Go
481 lines
14 KiB
Go
package upstream
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto"
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/http/httputil"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
|
|
. "github.com/onsi/ginkgo"
|
|
. "github.com/onsi/ginkgo/extensions/table"
|
|
. "github.com/onsi/gomega"
|
|
"golang.org/x/net/websocket"
|
|
)
|
|
|
|
var _ = Describe("HTTP Upstream Suite", func() {
|
|
|
|
const flushInterval5s = options.Duration(5 * time.Second)
|
|
const flushInterval1s = options.Duration(1 * time.Second)
|
|
truth := true
|
|
falsum := false
|
|
|
|
type httpUpstreamTableInput struct {
|
|
id string
|
|
serverAddr *string
|
|
target string
|
|
method string
|
|
body []byte
|
|
passUpstreamHostHeader bool
|
|
signatureData *options.SignatureData
|
|
existingHeaders map[string]string
|
|
expectedResponse testHTTPResponse
|
|
expectedUpstream string
|
|
errorHandler ProxyErrorHandler
|
|
}
|
|
|
|
DescribeTable("HTTP Upstream ServeHTTP",
|
|
func(in *httpUpstreamTableInput) {
|
|
buf := bytes.NewBuffer(in.body)
|
|
req := httptest.NewRequest(in.method, in.target, buf)
|
|
// Don't mock the remote Address
|
|
req.RemoteAddr = ""
|
|
|
|
for key, value := range in.existingHeaders {
|
|
req.Header.Add(key, value)
|
|
}
|
|
if host := req.Header.Get("Host"); host != "" {
|
|
req.Host = host
|
|
}
|
|
|
|
req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{})
|
|
rw := httptest.NewRecorder()
|
|
|
|
flush := options.Duration(1 * time.Second)
|
|
|
|
upstream := options.Upstream{
|
|
ID: in.id,
|
|
PassHostHeader: &in.passUpstreamHostHeader,
|
|
ProxyWebSockets: &falsum,
|
|
InsecureSkipTLSVerify: false,
|
|
FlushInterval: &flush,
|
|
}
|
|
|
|
Expect(in.serverAddr).ToNot(BeNil())
|
|
u, err := url.Parse(*in.serverAddr)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
handler := newHTTPUpstreamProxy(upstream, u, in.signatureData, in.errorHandler)
|
|
handler.ServeHTTP(rw, req)
|
|
|
|
Expect(rw.Code).To(Equal(in.expectedResponse.code))
|
|
|
|
scope := middlewareapi.GetRequestScope(req)
|
|
Expect(scope.Upstream).To(Equal(in.expectedUpstream))
|
|
|
|
// Delete extra headers that aren't relevant to tests
|
|
testSanitizeResponseHeader(rw.Header())
|
|
Expect(rw.Header()).To(Equal(in.expectedResponse.header))
|
|
|
|
body := rw.Body.Bytes()
|
|
if in.expectedResponse.raw != "" || rw.Code != http.StatusOK {
|
|
Expect(string(body)).To(Equal(in.expectedResponse.raw))
|
|
return
|
|
}
|
|
|
|
// Compare the reflected request to the upstream
|
|
request := testHTTPRequest{}
|
|
Expect(json.Unmarshal(body, &request)).To(Succeed())
|
|
testSanitizeRequestHeader(request.Header)
|
|
Expect(request).To(Equal(in.expectedResponse.request))
|
|
},
|
|
Entry("request a path on the server", &httpUpstreamTableInput{
|
|
id: "default",
|
|
serverAddr: &serverAddr,
|
|
target: "http://example.localhost/foo",
|
|
method: "GET",
|
|
body: []byte{},
|
|
errorHandler: nil,
|
|
expectedResponse: testHTTPResponse{
|
|
code: 200,
|
|
header: map[string][]string{
|
|
contentType: {applicationJSON},
|
|
},
|
|
request: testHTTPRequest{
|
|
Method: "GET",
|
|
URL: "http://example.localhost/foo",
|
|
Header: map[string][]string{},
|
|
Body: []byte{},
|
|
Host: "example.localhost",
|
|
RequestURI: "http://example.localhost/foo",
|
|
},
|
|
},
|
|
expectedUpstream: "default",
|
|
}),
|
|
Entry("request a path with encoded slashes", &httpUpstreamTableInput{
|
|
id: "encodedSlashes",
|
|
serverAddr: &serverAddr,
|
|
target: "http://example.localhost/foo%2fbar/?baz=1",
|
|
method: "GET",
|
|
body: []byte{},
|
|
errorHandler: nil,
|
|
expectedResponse: testHTTPResponse{
|
|
code: 200,
|
|
header: map[string][]string{
|
|
contentType: {applicationJSON},
|
|
},
|
|
request: testHTTPRequest{
|
|
Method: "GET",
|
|
URL: "http://example.localhost/foo%2fbar/?baz=1",
|
|
Header: map[string][]string{},
|
|
Body: []byte{},
|
|
Host: "example.localhost",
|
|
RequestURI: "http://example.localhost/foo%2fbar/?baz=1",
|
|
},
|
|
},
|
|
expectedUpstream: "encodedSlashes",
|
|
}),
|
|
Entry("request a path with an empty query string", &httpUpstreamTableInput{
|
|
id: "default",
|
|
serverAddr: &serverAddr,
|
|
target: "http://example.localhost/foo?",
|
|
method: "GET",
|
|
body: []byte{},
|
|
errorHandler: nil,
|
|
expectedResponse: testHTTPResponse{
|
|
code: 200,
|
|
header: map[string][]string{
|
|
contentType: {applicationJSON},
|
|
},
|
|
request: testHTTPRequest{
|
|
Method: "GET",
|
|
URL: "http://example.localhost/foo?",
|
|
Header: map[string][]string{},
|
|
Body: []byte{},
|
|
Host: "example.localhost",
|
|
RequestURI: "http://example.localhost/foo?",
|
|
},
|
|
},
|
|
expectedUpstream: "default",
|
|
}),
|
|
Entry("when the request has a body", &httpUpstreamTableInput{
|
|
id: "requestWithBody",
|
|
serverAddr: &serverAddr,
|
|
target: "http://example.localhost/withBody",
|
|
method: "POST",
|
|
body: []byte("body"),
|
|
errorHandler: nil,
|
|
expectedResponse: testHTTPResponse{
|
|
code: 200,
|
|
header: map[string][]string{
|
|
contentType: {applicationJSON},
|
|
},
|
|
request: testHTTPRequest{
|
|
Method: "POST",
|
|
URL: "http://example.localhost/withBody",
|
|
Header: map[string][]string{
|
|
contentLength: {"4"},
|
|
},
|
|
Body: []byte("body"),
|
|
Host: "example.localhost",
|
|
RequestURI: "http://example.localhost/withBody",
|
|
},
|
|
},
|
|
expectedUpstream: "requestWithBody",
|
|
}),
|
|
Entry("when the upstream is unavailable", &httpUpstreamTableInput{
|
|
id: "unavailableUpstream",
|
|
serverAddr: &invalidServer,
|
|
target: "http://example.localhost/unavailableUpstream",
|
|
method: "GET",
|
|
body: []byte{},
|
|
errorHandler: nil,
|
|
expectedResponse: testHTTPResponse{
|
|
code: 502,
|
|
header: map[string][]string{},
|
|
request: testHTTPRequest{},
|
|
},
|
|
expectedUpstream: "unavailableUpstream",
|
|
}),
|
|
Entry("when the upstream is unavailable and an error handler is set", &httpUpstreamTableInput{
|
|
id: "withErrorHandler",
|
|
serverAddr: &invalidServer,
|
|
target: "http://example.localhost/withErrorHandler",
|
|
method: "GET",
|
|
body: []byte{},
|
|
errorHandler: func(rw http.ResponseWriter, _ *http.Request, _ error) {
|
|
rw.WriteHeader(502)
|
|
rw.Write([]byte("error"))
|
|
},
|
|
expectedResponse: testHTTPResponse{
|
|
code: 502,
|
|
header: map[string][]string{},
|
|
raw: "error",
|
|
request: testHTTPRequest{},
|
|
},
|
|
expectedUpstream: "withErrorHandler",
|
|
}),
|
|
Entry("with a signature", &httpUpstreamTableInput{
|
|
id: "withSignature",
|
|
serverAddr: &serverAddr,
|
|
target: "http://example.localhost/withSignature",
|
|
method: "GET",
|
|
body: []byte{},
|
|
signatureData: &options.SignatureData{
|
|
Hash: crypto.SHA256,
|
|
Key: "key",
|
|
},
|
|
errorHandler: nil,
|
|
expectedResponse: testHTTPResponse{
|
|
code: 200,
|
|
header: map[string][]string{
|
|
contentType: {applicationJSON},
|
|
},
|
|
request: testHTTPRequest{
|
|
Method: "GET",
|
|
URL: "http://example.localhost/withSignature",
|
|
Header: map[string][]string{
|
|
gapAuth: {""},
|
|
gapSignature: {"sha256 osMWI8Rr0Zr5HgNq6wakrgJITVJQMmFN1fXCesrqrmM="},
|
|
},
|
|
Body: []byte{},
|
|
Host: "example.localhost",
|
|
RequestURI: "http://example.localhost/withSignature",
|
|
},
|
|
},
|
|
expectedUpstream: "withSignature",
|
|
}),
|
|
Entry("with existing headers", &httpUpstreamTableInput{
|
|
id: "existingHeaders",
|
|
serverAddr: &serverAddr,
|
|
target: "http://example.localhost/existingHeaders",
|
|
method: "GET",
|
|
body: []byte{},
|
|
errorHandler: nil,
|
|
existingHeaders: map[string]string{
|
|
"Header1": "value1",
|
|
"Header2": "value2",
|
|
},
|
|
expectedResponse: testHTTPResponse{
|
|
code: 200,
|
|
header: map[string][]string{
|
|
contentType: {applicationJSON},
|
|
},
|
|
request: testHTTPRequest{
|
|
Method: "GET",
|
|
URL: "http://example.localhost/existingHeaders",
|
|
Header: map[string][]string{
|
|
"Header1": {"value1"},
|
|
"Header2": {"value2"},
|
|
},
|
|
Body: []byte{},
|
|
Host: "example.localhost",
|
|
RequestURI: "http://example.localhost/existingHeaders",
|
|
},
|
|
},
|
|
expectedUpstream: "existingHeaders",
|
|
}),
|
|
Entry("when passing the existing host header", &httpUpstreamTableInput{
|
|
id: "passExistingHostHeader",
|
|
serverAddr: &serverAddr,
|
|
target: "/existingHostHeader",
|
|
method: "GET",
|
|
body: []byte{},
|
|
errorHandler: nil,
|
|
passUpstreamHostHeader: true,
|
|
existingHeaders: map[string]string{
|
|
"Host": "existing-host",
|
|
},
|
|
expectedResponse: testHTTPResponse{
|
|
code: 200,
|
|
header: map[string][]string{
|
|
contentType: {applicationJSON},
|
|
},
|
|
request: testHTTPRequest{
|
|
Method: "GET",
|
|
URL: "/existingHostHeader",
|
|
Header: map[string][]string{},
|
|
Body: []byte{},
|
|
Host: "existing-host",
|
|
RequestURI: "/existingHostHeader",
|
|
},
|
|
},
|
|
expectedUpstream: "passExistingHostHeader",
|
|
}),
|
|
)
|
|
|
|
// With PassHostHeader disabled, the request's Host must end up as the
// upstream server's host rather than the host the client asked for.
It("ServeHTTP, when not passing a host header", func() {
	backend, err := url.Parse(serverAddr)
	Expect(err).ToNot(HaveOccurred())

	// FlushInterval is a pointer field, so it needs an addressable value.
	flush := options.Duration(time.Second)
	built := newHTTPUpstreamProxy(options.Upstream{
		ID:                    "noPassHost",
		PassHostHeader:        &falsum,
		ProxyWebSockets:       &falsum,
		InsecureSkipTLSVerify: false,
		FlushInterval:         &flush,
	}, backend, nil, nil)
	httpUpstream, isHTTPUpstream := built.(*httpUpstreamProxy)
	Expect(isHTTPUpstream).To(BeTrue())

	// Override the handler to just run the director and not actually send the request.
	inner := httpUpstream.handler
	httpUpstream.handler = http.HandlerFunc(func(_ http.ResponseWriter, proxied *http.Request) {
		reverseProxy, isReverseProxy := inner.(*httputil.ReverseProxy)
		Expect(isReverseProxy).To(BeTrue())
		reverseProxy.Director(proxied)
	})

	req := httptest.NewRequest("", "http://example.localhost/foo", nil)
	req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{})
	httpUpstream.ServeHTTP(httptest.NewRecorder(), req)

	Expect(req.Host).To(Equal(strings.TrimPrefix(serverAddr, "http://")))
})
|
|
|
|
type newUpstreamTableInput struct {
|
|
proxyWebSockets bool
|
|
flushInterval options.Duration
|
|
skipVerify bool
|
|
sigData *options.SignatureData
|
|
errorHandler func(http.ResponseWriter, *http.Request, error)
|
|
}
|
|
|
|
// Each entry builds a proxy from the given options and inspects the
// resulting struct instead of sending any traffic through it.
DescribeTable("newHTTPUpstreamProxy",
	func(in *newUpstreamTableInput) {
		backend, err := url.Parse("http://upstream:1234")
		Expect(err).ToNot(HaveOccurred())

		cfg := options.Upstream{
			ID:                    "foo123",
			FlushInterval:         &in.flushInterval,
			InsecureSkipTLSVerify: in.skipVerify,
			ProxyWebSockets:       &in.proxyWebSockets,
		}

		built, isHTTPUpstream := newHTTPUpstreamProxy(cfg, backend, in.sigData, in.errorHandler).(*httpUpstreamProxy)
		Expect(isHTTPUpstream).To(BeTrue())

		// Optional parts must be present exactly when they were requested.
		wantAuth := in.sigData != nil
		Expect(built.auth != nil).To(Equal(wantAuth))
		Expect(built.wsHandler != nil).To(Equal(in.proxyWebSockets))
		Expect(built.upstream).To(Equal(cfg.ID))
		Expect(built.handler).ToNot(BeNil())

		reverseProxy, isReverseProxy := built.handler.(*httputil.ReverseProxy)
		Expect(isReverseProxy).To(BeTrue())
		Expect(reverseProxy.FlushInterval).To(Equal(in.flushInterval.Duration()))

		wantErrorHandler := in.errorHandler != nil
		Expect(reverseProxy.ErrorHandler != nil).To(Equal(wantErrorHandler))

		// The transport is only inspected when TLS verification is skipped.
		if !in.skipVerify {
			return
		}
		Expect(reverseProxy.Transport).To(Equal(&http.Transport{
			TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
		}))
	},
	Entry("with proxy websockets", &newUpstreamTableInput{
		proxyWebSockets: true,
		flushInterval:   flushInterval1s,
	}),
	Entry("with a non standard flush interval", &newUpstreamTableInput{
		flushInterval: flushInterval5s,
	}),
	Entry("with a InsecureSkipTLSVerify", &newUpstreamTableInput{
		flushInterval: flushInterval1s,
		skipVerify:    true,
	}),
	Entry("with a SignatureData", &newUpstreamTableInput{
		flushInterval: flushInterval1s,
		sigData:       &options.SignatureData{Hash: crypto.SHA256, Key: "secret"},
	}),
	Entry("with an error handler", &newUpstreamTableInput{
		flushInterval: flushInterval1s,
		errorHandler: func(rw http.ResponseWriter, _ *http.Request, _ error) {
			rw.WriteHeader(http.StatusBadGateway)
		},
	}),
)
|
|
|
|
// These specs run the upstream proxy behind a real httptest server with
// websocket proxying enabled, and check that both websocket and plain HTTP
// traffic make it through.
Context("with a websocket proxy", func() {
	var proxyServer *httptest.Server

	// Start a fresh proxy server in front of serverAddr for every spec.
	BeforeEach(func() {
		// FlushInterval is a pointer field, so it needs an addressable value.
		flush := options.Duration(1 * time.Second)
		upstream := options.Upstream{
			ID:                    "websocketProxy",
			PassHostHeader:        &truth,
			ProxyWebSockets:       &truth,
			InsecureSkipTLSVerify: false,
			FlushInterval:         &flush,
		}

		u, err := url.Parse(serverAddr)
		Expect(err).ToNot(HaveOccurred())

		handler := newHTTPUpstreamProxy(upstream, u, nil, nil)

		// The handler is wrapped in the scope middleware here; the specs that
		// call ServeHTTP directly attach a request scope by hand instead.
		proxyServer = httptest.NewServer(middleware.NewScope(false, "X-Request-Id")(handler))
	})

	AfterEach(func() {
		proxyServer.Close()
	})

	It("will proxy websockets", func() {
		origin := "http://example.localhost"
		message := "Hello, world!"

		// Dial the proxy's listener directly using the ws:// scheme.
		wsAddr := fmt.Sprintf("ws://%s/", proxyServer.Listener.Addr().String())
		ws, err := websocket.Dial(wsAddr, "", origin)
		Expect(err).ToNot(HaveOccurred())
		// Release the connection when the spec ends. The close error is
		// ignored on purpose: this is cleanup, not part of what is asserted.
		defer func() { _ = ws.Close() }()

		Expect(websocket.Message.Send(ws, []byte(message))).To(Succeed())

		var response testWebSocketResponse
		Expect(websocket.JSON.Receive(ws, &response)).To(Succeed())
		Expect(response).To(Equal(testWebSocketResponse{
			Message: message,
			Origin:  origin,
		}))
	})

	It("will proxy HTTP requests", func() {
		response, err := http.Get(proxyServer.URL)
		Expect(err).ToNot(HaveOccurred())
		// The body must be closed so the underlying connection is released.
		// The close error is ignored on purpose: this is cleanup only.
		defer func() { _ = response.Body.Close() }()

		Expect(response.StatusCode).To(Equal(200))
	})
})
|
|
})
|