mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-02-09 13:46:51 +02:00
Add tests for upstream package
This commit is contained in:
parent
fa8e1ee033
commit
5b95ed3033
58
pkg/upstream/file_test.go
Normal file
58
pkg/upstream/file_test.go
Normal file
@ -0,0 +1,58 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/ginkgo/extensions/table"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("File Server Suite", func() {
|
||||
var dir string
|
||||
var handler http.Handler
|
||||
var id string
|
||||
|
||||
const (
|
||||
foo = "foo"
|
||||
bar = "bar"
|
||||
baz = "baz"
|
||||
pageNotFound = "404 page not found\n"
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
// Generate a random id before each test to check the GAP-Upstream-Address
|
||||
// is being set correctly
|
||||
idBytes := make([]byte, 16)
|
||||
_, err := io.ReadFull(rand.Reader, idBytes)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
id = string(idBytes)
|
||||
|
||||
handler = newFileServer(id, "/files", filesDir)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(os.RemoveAll(dir)).To(Succeed())
|
||||
})
|
||||
|
||||
DescribeTable("fileServer ServeHTTP",
|
||||
func(requestPath string, expectedResponseCode int, expectedBody string) {
|
||||
req := httptest.NewRequest("", requestPath, nil)
|
||||
rw := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rw, req)
|
||||
|
||||
Expect(rw.Header().Get("GAP-Upstream-Address")).To(Equal(id))
|
||||
Expect(rw.Code).To(Equal(expectedResponseCode))
|
||||
Expect(rw.Body.String()).To(Equal(expectedBody))
|
||||
},
|
||||
Entry("for file foo", "/files/foo", 200, foo),
|
||||
Entry("for file bar", "/files/bar", 200, bar),
|
||||
Entry("for file foo/baz", "/files/subdir/baz", 200, baz),
|
||||
Entry("for a non-existent file inside the path", "/files/baz", 404, pageNotFound),
|
||||
Entry("for a non-existent file oustide the path", "/baz", 404, pageNotFound),
|
||||
)
|
||||
})
|
@ -6,6 +6,7 @@ import (
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mbland/hmacauth"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
||||
@ -96,7 +97,12 @@ func newReverseProxy(target *url.URL, upstream options.Upstream, errorHandler Pr
|
||||
proxy := httputil.NewSingleHostReverseProxy(target)
|
||||
|
||||
// Configure options on the SingleHostReverseProxy
|
||||
proxy.FlushInterval = *upstream.FlushInterval
|
||||
if upstream.FlushInterval != nil {
|
||||
proxy.FlushInterval = *upstream.FlushInterval
|
||||
} else {
|
||||
proxy.FlushInterval = 1 * time.Second
|
||||
}
|
||||
|
||||
if upstream.InsecureSkipTLSVerify {
|
||||
proxy.Transport = &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
|
417
pkg/upstream/http_test.go
Normal file
417
pkg/upstream/http_test.go
Normal file
@ -0,0 +1,417 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/ginkgo/extensions/table"
|
||||
. "github.com/onsi/gomega"
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
var _ = Describe("HTTP Upstream Suite", func() {
|
||||
|
||||
const flushInterval5s = 5 * time.Second
|
||||
const flushInterval1s = 1 * time.Second
|
||||
|
||||
type httpUpstreamTableInput struct {
|
||||
id string
|
||||
serverAddr *string
|
||||
target string
|
||||
method string
|
||||
body []byte
|
||||
signatureData *options.SignatureData
|
||||
existingHeaders map[string]string
|
||||
expectedResponse testHTTPResponse
|
||||
errorHandler ProxyErrorHandler
|
||||
}
|
||||
|
||||
DescribeTable("HTTP Upstream ServeHTTP",
|
||||
func(in *httpUpstreamTableInput) {
|
||||
buf := bytes.NewBuffer(in.body)
|
||||
req := httptest.NewRequest(in.method, in.target, buf)
|
||||
// Don't mock the remote Address
|
||||
req.RemoteAddr = ""
|
||||
|
||||
for key, value := range in.existingHeaders {
|
||||
req.Header.Add(key, value)
|
||||
}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
flush := 1 * time.Second
|
||||
upstream := options.Upstream{
|
||||
ID: in.id,
|
||||
PassHostHeader: true,
|
||||
ProxyWebSockets: false,
|
||||
InsecureSkipTLSVerify: false,
|
||||
FlushInterval: &flush,
|
||||
}
|
||||
|
||||
Expect(in.serverAddr).ToNot(BeNil())
|
||||
u, err := url.Parse(*in.serverAddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
handler := newHTTPUpstreamProxy(upstream, u, in.signatureData, in.errorHandler)
|
||||
handler.ServeHTTP(rw, req)
|
||||
|
||||
Expect(rw.Code).To(Equal(in.expectedResponse.code))
|
||||
|
||||
// Delete extra headers that aren't relevant to tests
|
||||
testSanitizeResponseHeader(rw.Header())
|
||||
Expect(rw.Header()).To(Equal(in.expectedResponse.header))
|
||||
|
||||
body := rw.Body.Bytes()
|
||||
if in.expectedResponse.raw != "" || rw.Code != http.StatusOK {
|
||||
Expect(string(body)).To(Equal(in.expectedResponse.raw))
|
||||
return
|
||||
}
|
||||
|
||||
// Compare the reflected request to the upstream
|
||||
request := testHTTPRequest{}
|
||||
Expect(json.Unmarshal(body, &request)).To(Succeed())
|
||||
testSanitizeRequestHeader(request.Header)
|
||||
Expect(request).To(Equal(in.expectedResponse.request))
|
||||
},
|
||||
Entry("request a path on the server", &httpUpstreamTableInput{
|
||||
id: "default",
|
||||
serverAddr: &serverAddr,
|
||||
target: "http://example.localhost/foo",
|
||||
method: "GET",
|
||||
body: []byte{},
|
||||
errorHandler: nil,
|
||||
expectedResponse: testHTTPResponse{
|
||||
code: 200,
|
||||
header: map[string][]string{
|
||||
gapUpstream: {"default"},
|
||||
contentType: {applicationJSON},
|
||||
},
|
||||
request: testHTTPRequest{
|
||||
Method: "GET",
|
||||
URL: "http://example.localhost/foo",
|
||||
Header: map[string][]string{},
|
||||
Body: []byte{},
|
||||
Host: "example.localhost",
|
||||
RequestURI: "http://example.localhost/foo",
|
||||
},
|
||||
},
|
||||
}),
|
||||
Entry("request a path with encoded slashes", &httpUpstreamTableInput{
|
||||
id: "encodedSlashes",
|
||||
serverAddr: &serverAddr,
|
||||
target: "http://example.localhost/foo%2fbar/?baz=1",
|
||||
method: "GET",
|
||||
body: []byte{},
|
||||
errorHandler: nil,
|
||||
expectedResponse: testHTTPResponse{
|
||||
code: 200,
|
||||
header: map[string][]string{
|
||||
gapUpstream: {"encodedSlashes"},
|
||||
contentType: {applicationJSON},
|
||||
},
|
||||
request: testHTTPRequest{
|
||||
Method: "GET",
|
||||
URL: "http://example.localhost/foo%2fbar/?baz=1",
|
||||
Header: map[string][]string{},
|
||||
Body: []byte{},
|
||||
Host: "example.localhost",
|
||||
RequestURI: "http://example.localhost/foo%2fbar/?baz=1",
|
||||
},
|
||||
},
|
||||
}),
|
||||
Entry("when the request has a body", &httpUpstreamTableInput{
|
||||
id: "requestWithBody",
|
||||
serverAddr: &serverAddr,
|
||||
target: "http://example.localhost/withBody",
|
||||
method: "POST",
|
||||
body: []byte("body"),
|
||||
errorHandler: nil,
|
||||
expectedResponse: testHTTPResponse{
|
||||
code: 200,
|
||||
header: map[string][]string{
|
||||
gapUpstream: {"requestWithBody"},
|
||||
contentType: {applicationJSON},
|
||||
},
|
||||
request: testHTTPRequest{
|
||||
Method: "POST",
|
||||
URL: "http://example.localhost/withBody",
|
||||
Header: map[string][]string{
|
||||
contentLength: {"4"},
|
||||
},
|
||||
Body: []byte("body"),
|
||||
Host: "example.localhost",
|
||||
RequestURI: "http://example.localhost/withBody",
|
||||
},
|
||||
},
|
||||
}),
|
||||
Entry("when the upstream is unavailable", &httpUpstreamTableInput{
|
||||
id: "unavailableUpstream",
|
||||
serverAddr: &invalidServer,
|
||||
target: "http://example.localhost/unavailableUpstream",
|
||||
method: "GET",
|
||||
body: []byte{},
|
||||
errorHandler: nil,
|
||||
expectedResponse: testHTTPResponse{
|
||||
code: 502,
|
||||
header: map[string][]string{
|
||||
gapUpstream: {"unavailableUpstream"},
|
||||
},
|
||||
request: testHTTPRequest{},
|
||||
},
|
||||
}),
|
||||
Entry("when the upstream is unavailable and an error handler is set", &httpUpstreamTableInput{
|
||||
id: "withErrorHandler",
|
||||
serverAddr: &invalidServer,
|
||||
target: "http://example.localhost/withErrorHandler",
|
||||
method: "GET",
|
||||
body: []byte{},
|
||||
errorHandler: func(rw http.ResponseWriter, _ *http.Request, _ error) {
|
||||
rw.WriteHeader(502)
|
||||
rw.Write([]byte("error"))
|
||||
},
|
||||
expectedResponse: testHTTPResponse{
|
||||
code: 502,
|
||||
header: map[string][]string{
|
||||
gapUpstream: {"withErrorHandler"},
|
||||
},
|
||||
raw: "error",
|
||||
request: testHTTPRequest{},
|
||||
},
|
||||
}),
|
||||
Entry("with a signature", &httpUpstreamTableInput{
|
||||
id: "withSignature",
|
||||
serverAddr: &serverAddr,
|
||||
target: "http://example.localhost/withSignature",
|
||||
method: "GET",
|
||||
body: []byte{},
|
||||
signatureData: &options.SignatureData{
|
||||
Hash: crypto.SHA256,
|
||||
Key: "key",
|
||||
},
|
||||
errorHandler: nil,
|
||||
expectedResponse: testHTTPResponse{
|
||||
code: 200,
|
||||
header: map[string][]string{
|
||||
contentType: {applicationJSON},
|
||||
gapUpstream: {"withSignature"},
|
||||
},
|
||||
request: testHTTPRequest{
|
||||
Method: "GET",
|
||||
URL: "http://example.localhost/withSignature",
|
||||
Header: map[string][]string{
|
||||
gapAuth: {""},
|
||||
gapSignature: {"sha256 osMWI8Rr0Zr5HgNq6wakrgJITVJQMmFN1fXCesrqrmM="},
|
||||
},
|
||||
Body: []byte{},
|
||||
Host: "example.localhost",
|
||||
RequestURI: "http://example.localhost/withSignature",
|
||||
},
|
||||
},
|
||||
}),
|
||||
Entry("with existing headers", &httpUpstreamTableInput{
|
||||
id: "existingHeaders",
|
||||
serverAddr: &serverAddr,
|
||||
target: "http://example.localhost/existingHeaders",
|
||||
method: "GET",
|
||||
body: []byte{},
|
||||
errorHandler: nil,
|
||||
existingHeaders: map[string]string{
|
||||
"Header1": "value1",
|
||||
"Header2": "value2",
|
||||
},
|
||||
expectedResponse: testHTTPResponse{
|
||||
code: 200,
|
||||
header: map[string][]string{
|
||||
gapUpstream: {"existingHeaders"},
|
||||
contentType: {applicationJSON},
|
||||
},
|
||||
request: testHTTPRequest{
|
||||
Method: "GET",
|
||||
URL: "http://example.localhost/existingHeaders",
|
||||
Header: map[string][]string{
|
||||
"Header1": {"value1"},
|
||||
"Header2": {"value2"},
|
||||
},
|
||||
Body: []byte{},
|
||||
Host: "example.localhost",
|
||||
RequestURI: "http://example.localhost/existingHeaders",
|
||||
},
|
||||
},
|
||||
}),
|
||||
)
|
||||
|
||||
It("ServeHTTP, when not passing a host header", func() {
|
||||
req := httptest.NewRequest("", "http://example.localhost/foo", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
flush := 1 * time.Second
|
||||
upstream := options.Upstream{
|
||||
ID: "noPassHost",
|
||||
PassHostHeader: false,
|
||||
ProxyWebSockets: false,
|
||||
InsecureSkipTLSVerify: false,
|
||||
FlushInterval: &flush,
|
||||
}
|
||||
|
||||
u, err := url.Parse(serverAddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
handler := newHTTPUpstreamProxy(upstream, u, nil, nil)
|
||||
httpUpstream, ok := handler.(*httpUpstreamProxy)
|
||||
Expect(ok).To(BeTrue())
|
||||
|
||||
// Override the handler to just run the director and not actually send the request
|
||||
requestInterceptor := func(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
|
||||
proxy, ok := h.(*httputil.ReverseProxy)
|
||||
Expect(ok).To(BeTrue())
|
||||
proxy.Director(req)
|
||||
})
|
||||
}
|
||||
httpUpstream.handler = requestInterceptor(httpUpstream.handler)
|
||||
|
||||
httpUpstream.ServeHTTP(rw, req)
|
||||
Expect(req.Host).To(Equal(strings.TrimPrefix(serverAddr, "http://")))
|
||||
})
|
||||
|
||||
type newUpstreamTableInput struct {
|
||||
proxyWebSockets bool
|
||||
flushInterval time.Duration
|
||||
skipVerify bool
|
||||
sigData *options.SignatureData
|
||||
errorHandler func(http.ResponseWriter, *http.Request, error)
|
||||
}
|
||||
|
||||
DescribeTable("newHTTPUpstreamProxy",
|
||||
func(in *newUpstreamTableInput) {
|
||||
u, err := url.Parse("http://upstream:1234")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
upstream := options.Upstream{
|
||||
ID: "foo123",
|
||||
FlushInterval: &in.flushInterval,
|
||||
InsecureSkipTLSVerify: in.skipVerify,
|
||||
ProxyWebSockets: in.proxyWebSockets,
|
||||
}
|
||||
|
||||
handler := newHTTPUpstreamProxy(upstream, u, in.sigData, in.errorHandler)
|
||||
upstreamProxy, ok := handler.(*httpUpstreamProxy)
|
||||
Expect(ok).To(BeTrue())
|
||||
|
||||
Expect(upstreamProxy.auth != nil).To(Equal(in.sigData != nil))
|
||||
Expect(upstreamProxy.wsHandler != nil).To(Equal(in.proxyWebSockets))
|
||||
Expect(upstreamProxy.upstream).To(Equal(upstream.ID))
|
||||
Expect(upstreamProxy.handler).ToNot(BeNil())
|
||||
|
||||
proxy, ok := upstreamProxy.handler.(*httputil.ReverseProxy)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(proxy.FlushInterval).To(Equal(in.flushInterval))
|
||||
Expect(proxy.ErrorHandler != nil).To(Equal(in.errorHandler != nil))
|
||||
if in.skipVerify {
|
||||
Expect(proxy.Transport).To(Equal(&http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}))
|
||||
}
|
||||
},
|
||||
Entry("with proxy websockets", &newUpstreamTableInput{
|
||||
proxyWebSockets: true,
|
||||
flushInterval: flushInterval1s,
|
||||
skipVerify: false,
|
||||
sigData: nil,
|
||||
errorHandler: nil,
|
||||
}),
|
||||
Entry("with a non standard flush interval", &newUpstreamTableInput{
|
||||
proxyWebSockets: false,
|
||||
flushInterval: flushInterval5s,
|
||||
skipVerify: false,
|
||||
sigData: nil,
|
||||
errorHandler: nil,
|
||||
}),
|
||||
Entry("with a InsecureSkipTLSVerify", &newUpstreamTableInput{
|
||||
proxyWebSockets: false,
|
||||
flushInterval: flushInterval1s,
|
||||
skipVerify: true,
|
||||
sigData: nil,
|
||||
errorHandler: nil,
|
||||
}),
|
||||
Entry("with a SignatureData", &newUpstreamTableInput{
|
||||
proxyWebSockets: false,
|
||||
flushInterval: flushInterval1s,
|
||||
skipVerify: false,
|
||||
sigData: &options.SignatureData{Hash: crypto.SHA256, Key: "secret"},
|
||||
errorHandler: nil,
|
||||
}),
|
||||
Entry("with an error handler", &newUpstreamTableInput{
|
||||
proxyWebSockets: false,
|
||||
flushInterval: flushInterval1s,
|
||||
skipVerify: false,
|
||||
sigData: nil,
|
||||
errorHandler: func(rw http.ResponseWriter, req *http.Request, arg3 error) {
|
||||
rw.WriteHeader(502)
|
||||
},
|
||||
}),
|
||||
)
|
||||
|
||||
Context("with a websocket proxy", func() {
|
||||
var proxyServer *httptest.Server
|
||||
|
||||
BeforeEach(func() {
|
||||
flush := 1 * time.Second
|
||||
upstream := options.Upstream{
|
||||
ID: "websocketProxy",
|
||||
PassHostHeader: true,
|
||||
ProxyWebSockets: true,
|
||||
InsecureSkipTLSVerify: false,
|
||||
FlushInterval: &flush,
|
||||
}
|
||||
|
||||
u, err := url.Parse(serverAddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
handler := newHTTPUpstreamProxy(upstream, u, nil, nil)
|
||||
proxyServer = httptest.NewServer(handler)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
proxyServer.Close()
|
||||
})
|
||||
|
||||
It("will proxy websockets", func() {
|
||||
origin := "http://example.localhost"
|
||||
message := "Hello, world!"
|
||||
|
||||
proxyURL, err := url.Parse(fmt.Sprintf("http://%s", proxyServer.Listener.Addr().String()))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
wsAddr := fmt.Sprintf("ws://%s/", proxyURL.Host)
|
||||
ws, err := websocket.Dial(wsAddr, "", origin)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(websocket.Message.Send(ws, []byte(message))).To(Succeed())
|
||||
var response testWebSocketResponse
|
||||
Expect(websocket.JSON.Receive(ws, &response)).To(Succeed())
|
||||
Expect(response).To(Equal(testWebSocketResponse{
|
||||
Message: message,
|
||||
Origin: origin,
|
||||
}))
|
||||
})
|
||||
|
||||
It("will proxy HTTP requests", func() {
|
||||
response, err := http.Get(fmt.Sprintf("http://%s", proxyServer.Listener.Addr().String()))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(response.StatusCode).To(Equal(200))
|
||||
Expect(response.Header.Get(gapUpstream)).To(Equal("websocketProxy"))
|
||||
})
|
||||
})
|
||||
})
|
182
pkg/upstream/proxy_test.go
Normal file
182
pkg/upstream/proxy_test.go
Normal file
@ -0,0 +1,182 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/ginkgo/extensions/table"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Proxy Suite", func() {
|
||||
var upstreamServer http.Handler
|
||||
|
||||
BeforeEach(func() {
|
||||
sigData := &options.SignatureData{Hash: crypto.SHA256, Key: "secret"}
|
||||
|
||||
tmpl, err := template.New("").Parse("{{ .Title }}\n{{ .Message }}\n{{ .ProxyPrefix }}")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
errorHandler := NewProxyErrorHandler(tmpl, "prefix")
|
||||
|
||||
ok := http.StatusOK
|
||||
|
||||
upstreams := options.Upstreams{
|
||||
{
|
||||
ID: "http-backend",
|
||||
Path: "/http/",
|
||||
URI: serverAddr,
|
||||
},
|
||||
{
|
||||
ID: "file-backend",
|
||||
Path: "/files/",
|
||||
URI: fmt.Sprintf("file:///%s", filesDir),
|
||||
},
|
||||
{
|
||||
ID: "static-backend",
|
||||
Path: "/static/",
|
||||
Static: true,
|
||||
StaticCode: &ok,
|
||||
},
|
||||
{
|
||||
ID: "bad-http-backend",
|
||||
Path: "/bad-http/",
|
||||
URI: "http://::1",
|
||||
},
|
||||
{
|
||||
ID: "single-path-backend",
|
||||
Path: "/single-path",
|
||||
Static: true,
|
||||
StaticCode: &ok,
|
||||
},
|
||||
}
|
||||
|
||||
upstreamServer, err = NewProxy(upstreams, sigData, errorHandler)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
type proxyTableInput struct {
|
||||
target string
|
||||
response testHTTPResponse
|
||||
}
|
||||
|
||||
DescribeTable("Proxy ServerHTTP",
|
||||
func(in *proxyTableInput) {
|
||||
req := httptest.NewRequest("", in.target, nil)
|
||||
rw := httptest.NewRecorder()
|
||||
// Don't mock the remote Address
|
||||
req.RemoteAddr = ""
|
||||
|
||||
upstreamServer.ServeHTTP(rw, req)
|
||||
|
||||
Expect(rw.Code).To(Equal(in.response.code))
|
||||
|
||||
// Delete extra headers that aren't relevant to tests
|
||||
testSanitizeResponseHeader(rw.Header())
|
||||
Expect(rw.Header()).To(Equal(in.response.header))
|
||||
|
||||
body := rw.Body.Bytes()
|
||||
// If the raw body is set, check that, else check the Request object
|
||||
if in.response.raw != "" {
|
||||
Expect(string(body)).To(Equal(in.response.raw))
|
||||
return
|
||||
}
|
||||
|
||||
// Compare the reflected request to the upstream
|
||||
request := testHTTPRequest{}
|
||||
Expect(json.Unmarshal(body, &request)).To(Succeed())
|
||||
testSanitizeRequestHeader(request.Header)
|
||||
Expect(request).To(Equal(in.response.request))
|
||||
},
|
||||
Entry("with a request to the HTTP service", &proxyTableInput{
|
||||
target: "http://example.localhost/http/1234",
|
||||
response: testHTTPResponse{
|
||||
code: 200,
|
||||
header: map[string][]string{
|
||||
gapUpstream: {"http-backend"},
|
||||
contentType: {applicationJSON},
|
||||
},
|
||||
request: testHTTPRequest{
|
||||
Method: "GET",
|
||||
URL: "http://example.localhost/http/1234",
|
||||
Header: map[string][]string{
|
||||
"Gap-Auth": {""},
|
||||
"Gap-Signature": {"sha256 ofB1u6+FhEUbFLc3/uGbJVkl7GaN4egFqVvyO3+2I1w="},
|
||||
},
|
||||
Body: []byte{},
|
||||
Host: "example.localhost",
|
||||
RequestURI: "http://example.localhost/http/1234",
|
||||
},
|
||||
},
|
||||
}),
|
||||
Entry("with a request to the File backend", &proxyTableInput{
|
||||
target: "http://example.localhost/files/foo",
|
||||
response: testHTTPResponse{
|
||||
code: 200,
|
||||
header: map[string][]string{
|
||||
contentType: {textPlainUTF8},
|
||||
gapUpstream: {"file-backend"},
|
||||
},
|
||||
raw: "foo",
|
||||
},
|
||||
}),
|
||||
Entry("with a request to the Static backend", &proxyTableInput{
|
||||
target: "http://example.localhost/static/bar",
|
||||
response: testHTTPResponse{
|
||||
code: 200,
|
||||
header: map[string][]string{
|
||||
gapUpstream: {"static-backend"},
|
||||
},
|
||||
raw: "Authenticated",
|
||||
},
|
||||
}),
|
||||
Entry("with a request to the bad HTTP backend", &proxyTableInput{
|
||||
target: "http://example.localhost/bad-http/bad",
|
||||
response: testHTTPResponse{
|
||||
code: 502,
|
||||
header: map[string][]string{
|
||||
gapUpstream: {"bad-http-backend"},
|
||||
},
|
||||
// This tests the error handler
|
||||
raw: "Bad Gateway\nError proxying to upstream server\nprefix",
|
||||
},
|
||||
}),
|
||||
Entry("with a request to the to an unregistered path", &proxyTableInput{
|
||||
target: "http://example.localhost/unregistered",
|
||||
response: testHTTPResponse{
|
||||
code: 404,
|
||||
header: map[string][]string{
|
||||
"X-Content-Type-Options": {"nosniff"},
|
||||
contentType: {textPlainUTF8},
|
||||
},
|
||||
raw: "404 page not found\n",
|
||||
},
|
||||
}),
|
||||
Entry("with a request to the to backend registered to a single path", &proxyTableInput{
|
||||
target: "http://example.localhost/single-path",
|
||||
response: testHTTPResponse{
|
||||
code: 200,
|
||||
header: map[string][]string{
|
||||
gapUpstream: {"single-path-backend"},
|
||||
},
|
||||
raw: "Authenticated",
|
||||
},
|
||||
}),
|
||||
Entry("with a request to the to a subpath of a backend registered to a single path", &proxyTableInput{
|
||||
target: "http://example.localhost/single-path/unregistered",
|
||||
response: testHTTPResponse{
|
||||
code: 404,
|
||||
header: map[string][]string{
|
||||
"X-Content-Type-Options": {"nosniff"},
|
||||
contentType: {textPlainUTF8},
|
||||
},
|
||||
raw: "404 page not found\n",
|
||||
},
|
||||
}),
|
||||
)
|
||||
})
|
81
pkg/upstream/static_test.go
Normal file
81
pkg/upstream/static_test.go
Normal file
@ -0,0 +1,81 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/ginkgo/extensions/table"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Static Response Suite", func() {
|
||||
const authenticated = "Authenticated"
|
||||
var id string
|
||||
|
||||
BeforeEach(func() {
|
||||
// Generate a random id before each test to check the GAP-Upstream-Address
|
||||
// is being set correctly
|
||||
idBytes := make([]byte, 16)
|
||||
_, err := io.ReadFull(rand.Reader, idBytes)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
id = string(idBytes)
|
||||
})
|
||||
|
||||
type serveHTTPTableInput struct {
|
||||
requestPath string
|
||||
staticCode int
|
||||
expectedBody string
|
||||
expectedCode int
|
||||
}
|
||||
|
||||
DescribeTable("staticResponse ServeHTTP",
|
||||
func(in *serveHTTPTableInput) {
|
||||
var code *int
|
||||
if in.staticCode != 0 {
|
||||
code = &in.staticCode
|
||||
}
|
||||
handler := newStaticResponseHandler(id, code)
|
||||
|
||||
req := httptest.NewRequest("", in.requestPath, nil)
|
||||
rw := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rw, req)
|
||||
|
||||
Expect(rw.Header().Get("GAP-Upstream-Address")).To(Equal(id))
|
||||
Expect(rw.Code).To(Equal(in.expectedCode))
|
||||
Expect(rw.Body.String()).To(Equal(in.expectedBody))
|
||||
},
|
||||
Entry("with no given code", &serveHTTPTableInput{
|
||||
requestPath: "/",
|
||||
staticCode: 0, // Placeholder for nil
|
||||
expectedBody: authenticated,
|
||||
expectedCode: http.StatusOK,
|
||||
}),
|
||||
Entry("with status OK", &serveHTTPTableInput{
|
||||
requestPath: "/abc",
|
||||
staticCode: http.StatusOK,
|
||||
expectedBody: authenticated,
|
||||
expectedCode: http.StatusOK,
|
||||
}),
|
||||
Entry("with status NoContent", &serveHTTPTableInput{
|
||||
requestPath: "/def",
|
||||
staticCode: http.StatusNoContent,
|
||||
expectedBody: authenticated,
|
||||
expectedCode: http.StatusNoContent,
|
||||
}),
|
||||
Entry("with status NotFound", &serveHTTPTableInput{
|
||||
requestPath: "/ghi",
|
||||
staticCode: http.StatusNotFound,
|
||||
expectedBody: authenticated,
|
||||
expectedCode: http.StatusNotFound,
|
||||
}),
|
||||
Entry("with status Teapot", &serveHTTPTableInput{
|
||||
requestPath: "/jkl",
|
||||
staticCode: http.StatusTeapot,
|
||||
expectedBody: authenticated,
|
||||
expectedCode: http.StatusTeapot,
|
||||
}),
|
||||
)
|
||||
})
|
180
pkg/upstream/upstream_suite_test.go
Normal file
180
pkg/upstream/upstream_suite_test.go
Normal file
@ -0,0 +1,180 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
var (
|
||||
filesDir string
|
||||
server *httptest.Server
|
||||
serverAddr string
|
||||
invalidServer = "http://::1"
|
||||
)
|
||||
|
||||
func TestUpstreamSuite(t *testing.T) {
|
||||
logger.SetOutput(GinkgoWriter)
|
||||
log.SetOutput(GinkgoWriter)
|
||||
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Upstream Suite")
|
||||
}
|
||||
|
||||
var _ = BeforeSuite(func() {
|
||||
// Set up files for serving via file servers
|
||||
dir, err := ioutil.TempDir("", "oauth2-proxy-upstream-suite")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(ioutil.WriteFile(path.Join(dir, "foo"), []byte("foo"), 0644)).To(Succeed())
|
||||
Expect(ioutil.WriteFile(path.Join(dir, "bar"), []byte("bar"), 0644)).To(Succeed())
|
||||
Expect(os.Mkdir(path.Join(dir, "subdir"), os.ModePerm)).To(Succeed())
|
||||
Expect(ioutil.WriteFile(path.Join(dir, "subdir", "baz"), []byte("baz"), 0644)).To(Succeed())
|
||||
filesDir = dir
|
||||
|
||||
// Set up a webserver that reflects requests
|
||||
server = httptest.NewServer(&testHTTPUpstream{})
|
||||
serverAddr = fmt.Sprintf("http://%s", server.Listener.Addr().String())
|
||||
})
|
||||
|
||||
var _ = AfterSuite(func() {
|
||||
server.Close()
|
||||
Expect(os.RemoveAll(filesDir)).To(Succeed())
|
||||
})
|
||||
|
||||
const (
|
||||
contentType = "Content-Type"
|
||||
contentLength = "Content-Length"
|
||||
acceptEncoding = "Accept-Encoding"
|
||||
applicationJSON = "application/json"
|
||||
textPlainUTF8 = "text/plain; charset=utf-8"
|
||||
gapUpstream = "Gap-Upstream-Address"
|
||||
gapAuth = "Gap-Auth"
|
||||
gapSignature = "Gap-Signature"
|
||||
)
|
||||
|
||||
// testHTTPResponse is a struct used for checking responses in table tests
|
||||
type testHTTPResponse struct {
|
||||
code int
|
||||
header http.Header
|
||||
raw string
|
||||
request testHTTPRequest
|
||||
}
|
||||
|
||||
// testHTTPRequest is a struct used to capture the state of a request made to
|
||||
// an upstream during a test
|
||||
type testHTTPRequest struct {
|
||||
Method string
|
||||
URL string
|
||||
Header http.Header
|
||||
Body []byte
|
||||
Host string
|
||||
RequestURI string
|
||||
}
|
||||
|
||||
type testWebSocketResponse struct {
|
||||
Message string
|
||||
Origin string
|
||||
}
|
||||
|
||||
type testHTTPUpstream struct{}
|
||||
|
||||
func (t *testHTTPUpstream) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Header.Get("Upgrade") == "websocket" {
|
||||
t.websocketHandler().ServeHTTP(rw, req)
|
||||
} else {
|
||||
t.serveHTTP(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *testHTTPUpstream) serveHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
request, err := toTestHTTPRequest(req)
|
||||
if err != nil {
|
||||
t.writeError(rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
t.writeError(rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.Write(data)
|
||||
}
|
||||
|
||||
func (t *testHTTPUpstream) websocketHandler() http.Handler {
|
||||
return websocket.Handler(func(ws *websocket.Conn) {
|
||||
defer ws.Close()
|
||||
var data []byte
|
||||
err := websocket.Message.Receive(ws, &data)
|
||||
if err != nil {
|
||||
websocket.Message.Send(ws, []byte(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
wsResponse := testWebSocketResponse{
|
||||
Message: string(data),
|
||||
Origin: ws.Request().Header.Get("Origin"),
|
||||
}
|
||||
err = websocket.JSON.Send(ws, wsResponse)
|
||||
if err != nil {
|
||||
websocket.Message.Send(ws, []byte(err.Error()))
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (t *testHTTPUpstream) writeError(rw http.ResponseWriter, err error) {
|
||||
rw.WriteHeader(500)
|
||||
if err != nil {
|
||||
rw.Write([]byte(err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
func toTestHTTPRequest(req *http.Request) (testHTTPRequest, error) {
|
||||
requestBody := []byte{}
|
||||
if req.Body != http.NoBody {
|
||||
var err error
|
||||
requestBody, err = ioutil.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return testHTTPRequest{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return testHTTPRequest{
|
||||
Method: req.Method,
|
||||
URL: req.URL.String(),
|
||||
Header: req.Header,
|
||||
Body: requestBody,
|
||||
Host: req.Host,
|
||||
RequestURI: req.RequestURI,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// String headers added to the response that we do not want to test
|
||||
func testSanitizeResponseHeader(h http.Header) {
|
||||
// From HTTP responses
|
||||
h.Del("Date")
|
||||
h.Del(contentLength)
|
||||
|
||||
// From File responses
|
||||
h.Del("Accept-Ranges")
|
||||
h.Del("Last-Modified")
|
||||
}
|
||||
|
||||
// Strip the accept header that is added by the HTTP Transport
|
||||
func testSanitizeRequestHeader(h http.Header) {
|
||||
h.Del(acceptEncoding)
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user