From 20a152261c6712c1ac7e2de8fbc24231be635b92 Mon Sep 17 00:00:00 2001 From: John Boxall Date: Sun, 30 Nov 2014 17:12:33 -0800 Subject: [PATCH] Adds failing test for using upstream Host header. --- oauthproxy.go | 6 +++++- oauthproxy_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 oauthproxy_test.go diff --git a/oauthproxy.go b/oauthproxy.go index 7951a2a5..16c0dadb 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -46,6 +46,10 @@ type OauthProxy struct { compiledRegex []*regexp.Regexp } +func NewReverseProxy(target *url.URL) (proxy *httputil.ReverseProxy) { + return httputil.NewSingleHostReverseProxy(target) +} + func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy { login, _ := url.Parse("https://accounts.google.com/o/oauth2/auth") redeem, _ := url.Parse("https://accounts.google.com/o/oauth2/token") @@ -54,7 +58,7 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy { path := u.Path u.Path = "" log.Printf("mapping path %q => upstream %q", path, u) - serveMux.Handle(path, httputil.NewSingleHostReverseProxy(u)) + serveMux.Handle(path, NewReverseProxy(u)) } for _, u := range opts.CompiledRegex { log.Printf("compiled skip-auth-regex => %q", u) diff --git a/oauthproxy_test.go b/oauthproxy_test.go new file mode 100644 index 00000000..c4e68b3b --- /dev/null +++ b/oauthproxy_test.go @@ -0,0 +1,36 @@ +package main + +import ( + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func TestNewReverseProxy(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + hostname, _, _ := net.SplitHostPort(r.Host) + w.Write([]byte(hostname)) + })) + defer backend.Close() + + backendURL, _ := url.Parse(backend.URL) + backendHostname := "upstream.127.0.0.1.xip.io" + _, backendPort, _ := net.SplitHostPort(backendURL.Host) + backendHost := net.JoinHostPort(backendHostname, backendPort) + proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/") + + proxyHandler := NewReverseProxy(proxyURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + res, _ := http.DefaultClient.Do(getReq) + bodyBytes, _ := ioutil.ReadAll(res.Body) + if g, e := string(bodyBytes), backendHostname; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +}