1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2024-11-28 09:08:44 +02:00
oauth2-proxy/pkg/upstream/http_test.go
Jack Henschel 7a27cb04df Implement configurable timeout for upstream connections
Signed-off-by: Jack Henschel <jack.henschel@cern.ch>
2022-05-18 11:41:17 +01:00

500 lines
15 KiB
Go

package upstream
import (
"bytes"
"crypto"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"strings"
"time"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
. "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega"
"golang.org/x/net/websocket"
)
var _ = Describe("HTTP Upstream Suite", func() {
defaultFlushInterval := options.Duration(options.DefaultUpstreamFlushInterval)
defaultTimeout := options.Duration(options.DefaultUpstreamTimeout)
truth := true
falsum := false
// httpUpstreamTableInput is one row of the "HTTP Upstream ServeHTTP" table.
type httpUpstreamTableInput struct {
	// Upstream under test: its ID, the address it proxies to (dereferenced
	// when the entry runs) and the arguments passed to newHTTPUpstreamProxy.
	id                     string
	serverAddr             *string
	passUpstreamHostHeader bool
	signatureData          *options.SignatureData
	errorHandler           ProxyErrorHandler

	// Request that is sent through the proxy handler.
	method          string
	target          string
	body            []byte
	existingHeaders map[string]string

	// Expectations checked after the handler has served the request.
	expectedResponse testHTTPResponse
	expectedUpstream string
}
DescribeTable("HTTP Upstream ServeHTTP",
	// Each entry sends a single request through a freshly built upstream
	// proxy and then checks, in order: the status code, the upstream ID
	// recorded in the request scope, the response headers and the body.
	func(in *httpUpstreamTableInput) {
		buf := bytes.NewBuffer(in.body)
		req := httptest.NewRequest(in.method, in.target, buf)
		// Don't mock the remote Address
		req.RemoteAddr = ""

		for key, value := range in.existingHeaders {
			req.Header.Add(key, value)
		}
		// A "Host" entry in existingHeaders is promoted to req.Host.
		if host := req.Header.Get("Host"); host != "" {
			req.Host = host
		}

		// The expectation on scope.Upstream below reads this scope back.
		req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{})
		rw := httptest.NewRecorder()

		flush := options.Duration(1 * time.Second)
		timeout := options.Duration(options.DefaultUpstreamTimeout)

		upstream := options.Upstream{
			ID:                    in.id,
			PassHostHeader:        &in.passUpstreamHostHeader,
			ProxyWebSockets:       &falsum,
			InsecureSkipTLSVerify: false,
			FlushInterval:         &flush,
			Timeout:               &timeout,
		}

		Expect(in.serverAddr).ToNot(BeNil())
		u, err := url.Parse(*in.serverAddr)
		Expect(err).ToNot(HaveOccurred())

		handler := newHTTPUpstreamProxy(upstream, u, in.signatureData, in.errorHandler)
		handler.ServeHTTP(rw, req)

		Expect(rw.Code).To(Equal(in.expectedResponse.code))

		scope := middlewareapi.GetRequestScope(req)
		Expect(scope.Upstream).To(Equal(in.expectedUpstream))

		// Delete extra headers that aren't relevant to tests
		testSanitizeResponseHeader(rw.Header())
		Expect(rw.Header()).To(Equal(in.expectedResponse.header))

		body := rw.Body.Bytes()
		// Entries with a raw expectation, and every non-200 response, are
		// compared as plain text rather than decoded as JSON.
		if in.expectedResponse.raw != "" || rw.Code != http.StatusOK {
			Expect(string(body)).To(Equal(in.expectedResponse.raw))
			return
		}

		// Compare the reflected request to the upstream
		request := testHTTPRequest{}
		Expect(json.Unmarshal(body, &request)).To(Succeed())
		testSanitizeRequestHeader(request.Header)
		Expect(request).To(Equal(in.expectedResponse.request))
	},
	Entry("request a path on the server", &httpUpstreamTableInput{
		id:           "default",
		serverAddr:   &serverAddr,
		target:       "http://example.localhost/foo",
		method:       http.MethodGet,
		body:         []byte{},
		errorHandler: nil,
		expectedResponse: testHTTPResponse{
			code: http.StatusOK,
			header: map[string][]string{
				contentType: {applicationJSON},
			},
			request: testHTTPRequest{
				Method:     http.MethodGet,
				URL:        "http://example.localhost/foo",
				Header:     map[string][]string{},
				Body:       []byte{},
				Host:       "example.localhost",
				RequestURI: "http://example.localhost/foo",
			},
		},
		expectedUpstream: "default",
	}),
	Entry("request a path with encoded slashes", &httpUpstreamTableInput{
		id:           "encodedSlashes",
		serverAddr:   &serverAddr,
		target:       "http://example.localhost/foo%2fbar/?baz=1",
		method:       http.MethodGet,
		body:         []byte{},
		errorHandler: nil,
		expectedResponse: testHTTPResponse{
			code: http.StatusOK,
			header: map[string][]string{
				contentType: {applicationJSON},
			},
			request: testHTTPRequest{
				Method:     http.MethodGet,
				URL:        "http://example.localhost/foo%2fbar/?baz=1",
				Header:     map[string][]string{},
				Body:       []byte{},
				Host:       "example.localhost",
				RequestURI: "http://example.localhost/foo%2fbar/?baz=1",
			},
		},
		expectedUpstream: "encodedSlashes",
	}),
	Entry("request a path with an empty query string", &httpUpstreamTableInput{
		id:           "default",
		serverAddr:   &serverAddr,
		target:       "http://example.localhost/foo?",
		method:       http.MethodGet,
		body:         []byte{},
		errorHandler: nil,
		expectedResponse: testHTTPResponse{
			code: http.StatusOK,
			header: map[string][]string{
				contentType: {applicationJSON},
			},
			request: testHTTPRequest{
				Method:     http.MethodGet,
				URL:        "http://example.localhost/foo?",
				Header:     map[string][]string{},
				Body:       []byte{},
				Host:       "example.localhost",
				RequestURI: "http://example.localhost/foo?",
			},
		},
		expectedUpstream: "default",
	}),
	Entry("when the request has a body", &httpUpstreamTableInput{
		id:           "requestWithBody",
		serverAddr:   &serverAddr,
		target:       "http://example.localhost/withBody",
		method:       http.MethodPost,
		body:         []byte("body"),
		errorHandler: nil,
		expectedResponse: testHTTPResponse{
			code: http.StatusOK,
			header: map[string][]string{
				contentType: {applicationJSON},
			},
			request: testHTTPRequest{
				Method: http.MethodPost,
				URL:    "http://example.localhost/withBody",
				Header: map[string][]string{
					contentLength: {"4"},
				},
				Body:       []byte("body"),
				Host:       "example.localhost",
				RequestURI: "http://example.localhost/withBody",
			},
		},
		expectedUpstream: "requestWithBody",
	}),
	Entry("when the upstream is unavailable", &httpUpstreamTableInput{
		id:           "unavailableUpstream",
		serverAddr:   &invalidServer,
		target:       "http://example.localhost/unavailableUpstream",
		method:       http.MethodGet,
		body:         []byte{},
		errorHandler: nil,
		expectedResponse: testHTTPResponse{
			code:    http.StatusBadGateway,
			header:  map[string][]string{},
			request: testHTTPRequest{},
		},
		expectedUpstream: "unavailableUpstream",
	}),
	Entry("when the upstream is unavailable and an error handler is set", &httpUpstreamTableInput{
		id:         "withErrorHandler",
		serverAddr: &invalidServer,
		target:     "http://example.localhost/withErrorHandler",
		method:     http.MethodGet,
		body:       []byte{},
		errorHandler: func(rw http.ResponseWriter, _ *http.Request, _ error) {
			rw.WriteHeader(http.StatusBadGateway)
			// Check the write result instead of silently discarding it.
			_, err := rw.Write([]byte("error"))
			Expect(err).ToNot(HaveOccurred())
		},
		expectedResponse: testHTTPResponse{
			code:    http.StatusBadGateway,
			header:  map[string][]string{},
			raw:     "error",
			request: testHTTPRequest{},
		},
		expectedUpstream: "withErrorHandler",
	}),
	Entry("with a signature", &httpUpstreamTableInput{
		id:         "withSignature",
		serverAddr: &serverAddr,
		target:     "http://example.localhost/withSignature",
		method:     http.MethodGet,
		body:       []byte{},
		signatureData: &options.SignatureData{
			Hash: crypto.SHA256,
			Key:  "key",
		},
		errorHandler: nil,
		expectedResponse: testHTTPResponse{
			code: http.StatusOK,
			header: map[string][]string{
				contentType: {applicationJSON},
			},
			request: testHTTPRequest{
				Method: http.MethodGet,
				URL:    "http://example.localhost/withSignature",
				Header: map[string][]string{
					gapAuth:      {""},
					gapSignature: {"sha256 osMWI8Rr0Zr5HgNq6wakrgJITVJQMmFN1fXCesrqrmM="},
				},
				Body:       []byte{},
				Host:       "example.localhost",
				RequestURI: "http://example.localhost/withSignature",
			},
		},
		expectedUpstream: "withSignature",
	}),
	Entry("with existing headers", &httpUpstreamTableInput{
		id:           "existingHeaders",
		serverAddr:   &serverAddr,
		target:       "http://example.localhost/existingHeaders",
		method:       http.MethodGet,
		body:         []byte{},
		errorHandler: nil,
		existingHeaders: map[string]string{
			"Header1": "value1",
			"Header2": "value2",
		},
		expectedResponse: testHTTPResponse{
			code: http.StatusOK,
			header: map[string][]string{
				contentType: {applicationJSON},
			},
			request: testHTTPRequest{
				Method: http.MethodGet,
				URL:    "http://example.localhost/existingHeaders",
				Header: map[string][]string{
					"Header1": {"value1"},
					"Header2": {"value2"},
				},
				Body:       []byte{},
				Host:       "example.localhost",
				RequestURI: "http://example.localhost/existingHeaders",
			},
		},
		expectedUpstream: "existingHeaders",
	}),
	Entry("when passing the existing host header", &httpUpstreamTableInput{
		id:                     "passExistingHostHeader",
		serverAddr:             &serverAddr,
		target:                 "/existingHostHeader",
		method:                 http.MethodGet,
		body:                   []byte{},
		errorHandler:           nil,
		passUpstreamHostHeader: true,
		existingHeaders: map[string]string{
			"Host": "existing-host",
		},
		expectedResponse: testHTTPResponse{
			code: http.StatusOK,
			header: map[string][]string{
				contentType: {applicationJSON},
			},
			request: testHTTPRequest{
				Method:     http.MethodGet,
				URL:        "/existingHostHeader",
				Header:     map[string][]string{},
				Body:       []byte{},
				Host:       "existing-host",
				RequestURI: "/existingHostHeader",
			},
		},
		expectedUpstream: "passExistingHostHeader",
	}),
)
// With PassHostHeader disabled, running the reverse proxy's Director on the
// request must leave req.Host equal to the upstream server's host.
It("ServeHTTP, when not passing a host header", func() {
	target, err := url.Parse(serverAddr)
	Expect(err).ToNot(HaveOccurred())

	proxyHandler := newHTTPUpstreamProxy(options.Upstream{
		ID:                    "noPassHost",
		PassHostHeader:        &falsum,
		ProxyWebSockets:       &falsum,
		InsecureSkipTLSVerify: false,
		FlushInterval:         &defaultFlushInterval,
		Timeout:               &defaultTimeout,
	}, target, nil, nil)
	upstreamProxy, isUpstreamProxy := proxyHandler.(*httpUpstreamProxy)
	Expect(isUpstreamProxy).To(BeTrue())

	// Override the handler to just run the director and not actually send the request
	wrapped := upstreamProxy.handler
	upstreamProxy.handler = http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
		reverseProxy, isReverseProxy := wrapped.(*httputil.ReverseProxy)
		Expect(isReverseProxy).To(BeTrue())
		reverseProxy.Director(r)
	})

	incoming := middlewareapi.AddRequestScope(
		httptest.NewRequest("", "http://example.localhost/foo", nil),
		&middlewareapi.RequestScope{},
	)
	upstreamProxy.ServeHTTP(httptest.NewRecorder(), incoming)

	expectedHost := strings.TrimPrefix(serverAddr, "http://")
	Expect(incoming.Host).To(Equal(expectedHost))
})
// newUpstreamTableInput is one row of the "newHTTPUpstreamProxy" table.
type newUpstreamTableInput struct {
	// Values copied into the options.Upstream handed to the constructor.
	proxyWebSockets, skipVerify bool
	flushInterval, timeout      options.Duration

	// Remaining constructor arguments, passed through as-is.
	sigData      *options.SignatureData
	errorHandler func(http.ResponseWriter, *http.Request, error)
}
DescribeTable("newHTTPUpstreamProxy",
	// Builds a proxy from each entry's options and inspects the resulting
	// httpUpstreamProxy, its reverse proxy and that proxy's transport.
	func(in *newUpstreamTableInput) {
		u, err := url.Parse("http://upstream:1234")
		Expect(err).ToNot(HaveOccurred())

		upstream := options.Upstream{
			ID:                    "foo123",
			FlushInterval:         &in.flushInterval,
			InsecureSkipTLSVerify: in.skipVerify,
			ProxyWebSockets:       &in.proxyWebSockets,
			Timeout:               &in.timeout,
		}

		handler := newHTTPUpstreamProxy(upstream, u, in.sigData, in.errorHandler)
		upstreamProxy, ok := handler.(*httpUpstreamProxy)
		Expect(ok).To(BeTrue())

		// auth and wsHandler must be set exactly when the matching option is.
		Expect(upstreamProxy.auth != nil).To(Equal(in.sigData != nil))
		Expect(upstreamProxy.wsHandler != nil).To(Equal(in.proxyWebSockets))
		Expect(upstreamProxy.upstream).To(Equal(upstream.ID))
		Expect(upstreamProxy.handler).ToNot(BeNil())

		proxy, ok := upstreamProxy.handler.(*httputil.ReverseProxy)
		Expect(ok).To(BeTrue())
		Expect(proxy.FlushInterval).To(Equal(in.flushInterval.Duration()))

		transport, ok := proxy.Transport.(*http.Transport)
		Expect(ok).To(BeTrue())
		Expect(transport.ResponseHeaderTimeout).To(Equal(in.timeout.Duration()))
		Expect(proxy.ErrorHandler != nil).To(Equal(in.errorHandler != nil))

		// The TLS setting is only asserted when skipVerify is requested.
		if in.skipVerify {
			Expect(transport.TLSClientConfig.InsecureSkipVerify).To(BeTrue())
		}
	},
	Entry("with proxy websockets", &newUpstreamTableInput{
		proxyWebSockets: true,
		flushInterval:   defaultFlushInterval,
		skipVerify:      false,
		sigData:         nil,
		errorHandler:    nil,
		timeout:         defaultTimeout,
	}),
	Entry("with a non standard flush interval", &newUpstreamTableInput{
		proxyWebSockets: false,
		flushInterval:   options.Duration(5 * time.Second),
		skipVerify:      false,
		sigData:         nil,
		errorHandler:    nil,
		timeout:         defaultTimeout,
	}),
	Entry("with a InsecureSkipTLSVerify", &newUpstreamTableInput{
		proxyWebSockets: false,
		flushInterval:   defaultFlushInterval,
		skipVerify:      true,
		sigData:         nil,
		errorHandler:    nil,
		timeout:         defaultTimeout,
	}),
	Entry("with a SignatureData", &newUpstreamTableInput{
		proxyWebSockets: false,
		flushInterval:   defaultFlushInterval,
		skipVerify:      false,
		sigData:         &options.SignatureData{Hash: crypto.SHA256, Key: "secret"},
		errorHandler:    nil,
		timeout:         defaultTimeout,
	}),
	Entry("with an error handler", &newUpstreamTableInput{
		proxyWebSockets: false,
		flushInterval:   defaultFlushInterval,
		skipVerify:      false,
		sigData:         nil,
		errorHandler: func(rw http.ResponseWriter, _ *http.Request, _ error) {
			rw.WriteHeader(http.StatusBadGateway)
		},
		timeout: defaultTimeout,
	}),
	Entry("with a non-default timeout", &newUpstreamTableInput{
		proxyWebSockets: false,
		flushInterval:   defaultFlushInterval,
		skipVerify:      false,
		sigData:         nil,
		errorHandler:    nil,
		timeout:         options.Duration(5 * time.Second),
	}),
)
// These specs run the upstream proxy behind a real httptest server with
// ProxyWebSockets enabled and exercise it with a websocket client and a
// plain HTTP client.
Context("with a websocket proxy", func() {
	var proxyServer *httptest.Server

	BeforeEach(func() {
		flush := options.Duration(1 * time.Second)
		timeout := options.Duration(options.DefaultUpstreamTimeout)
		upstream := options.Upstream{
			ID:                    "websocketProxy",
			PassHostHeader:        &truth,
			ProxyWebSockets:       &truth,
			InsecureSkipTLSVerify: false,
			FlushInterval:         &flush,
			Timeout:               &timeout,
		}

		u, err := url.Parse(serverAddr)
		Expect(err).ToNot(HaveOccurred())

		handler := newHTTPUpstreamProxy(upstream, u, nil, nil)
		// Wrap the handler in the scope middleware before serving it.
		proxyServer = httptest.NewServer(middleware.NewScope(false, "X-Request-Id")(handler))
	})

	AfterEach(func() {
		proxyServer.Close()
	})

	It("will proxy websockets", func() {
		origin := "http://example.localhost"
		message := "Hello, world!"

		proxyURL, err := url.Parse(fmt.Sprintf("http://%s", proxyServer.Listener.Addr().String()))
		Expect(err).ToNot(HaveOccurred())

		wsAddr := fmt.Sprintf("ws://%s/", proxyURL.Host)
		ws, err := websocket.Dial(wsAddr, "", origin)
		Expect(err).ToNot(HaveOccurred())
		// Release the client connection when the spec finishes, including
		// when one of the assertions below fails.
		defer ws.Close()

		Expect(websocket.Message.Send(ws, []byte(message))).To(Succeed())
		var response testWebSocketResponse
		Expect(websocket.JSON.Receive(ws, &response)).To(Succeed())
		Expect(response).To(Equal(testWebSocketResponse{
			Message: message,
			Origin:  origin,
		}))
	})

	It("will proxy HTTP requests", func() {
		response, err := http.Get(fmt.Sprintf("http://%s", proxyServer.Listener.Addr().String()))
		Expect(err).ToNot(HaveOccurred())
		// The caller must close the body of every successful response.
		defer response.Body.Close()
		Expect(response.StatusCode).To(Equal(http.StatusOK))
	})
})
})