1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2024-11-24 08:52:25 +02:00

Adds failing test for using upstream Host header.

This commit is contained in:
John Boxall 2014-11-30 17:12:33 -08:00 committed by Jehiah Czebotar
parent ade9502dd2
commit 20a152261c
2 changed files with 41 additions and 1 deletions

View File

@ -46,6 +46,10 @@ type OauthProxy struct {
compiledRegex []*regexp.Regexp
}
func NewReverseProxy(target *url.URL) (proxy *httputil.ReverseProxy) {
return httputil.NewSingleHostReverseProxy(target)
}
func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy {
login, _ := url.Parse("https://accounts.google.com/o/oauth2/auth")
redeem, _ := url.Parse("https://accounts.google.com/o/oauth2/token")
@ -54,7 +58,7 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy {
path := u.Path
u.Path = ""
log.Printf("mapping path %q => upstream %q", path, u)
serveMux.Handle(path, httputil.NewSingleHostReverseProxy(u))
serveMux.Handle(path, NewReverseProxy(u))
}
for _, u := range opts.CompiledRegex {
log.Printf("compiled skip-auth-regex => %q", u)

36
oauthproxy_test.go Normal file
View File

@ -0,0 +1,36 @@
package main
import (
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)
func TestNewReverseProxy(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
hostname, _, _ := net.SplitHostPort(r.Host)
w.Write([]byte(hostname))
}))
defer backend.Close()
backendURL, _ := url.Parse(backend.URL)
backendHostname := "upstream.127.0.0.1.xip.io"
_, backendPort, _ := net.SplitHostPort(backendURL.Host)
backendHost := net.JoinHostPort(backendHostname, backendPort)
proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/")
proxyHandler := NewReverseProxy(proxyURL)
frontend := httptest.NewServer(proxyHandler)
defer frontend.Close()
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
res, _ := http.DefaultClient.Do(getReq)
bodyBytes, _ := ioutil.ReadAll(res.Body)
if g, e := string(bodyBytes), backendHostname; g != e {
t.Errorf("got body %q; expected %q", g, e)
}
}