1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2024-11-28 09:08:44 +02:00
oauth2-proxy/oauthproxy_test.go

465 lines
13 KiB
Go
Raw Normal View History

package main
import (
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"net/http/httptest"
	"net/url"
	"regexp"
	"strings"
	"testing"
	"time"

	"github.com/bitly/oauth2_proxy/providers"
	"github.com/bmizerany/assert"
)
// init configures the standard logger so every line logged while the tests
// run is prefixed with the date, the time, and the caller's short file:line.
func init() {
	const flags = log.Ldate | log.Ltime | log.Lshortfile
	log.SetFlags(flags)
}
func TestNewReverseProxy(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
2015-03-17 21:15:15 +02:00
hostname, _, _ := net.SplitHostPort(r.Host)
w.Write([]byte(hostname))
}))
defer backend.Close()
backendURL, _ := url.Parse(backend.URL)
backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host)
backendHost := net.JoinHostPort(backendHostname, backendPort)
proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/")
proxyHandler := NewReverseProxy(proxyURL)
2015-03-17 21:15:15 +02:00
setProxyUpstreamHostHeader(proxyHandler, proxyURL)
frontend := httptest.NewServer(proxyHandler)
defer frontend.Close()
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
res, _ := http.DefaultClient.Do(getReq)
bodyBytes, _ := ioutil.ReadAll(res.Body)
if g, e := string(bodyBytes), backendHostname; g != e {
t.Errorf("got body %q; expected %q", g, e)
}
}
func TestEncodedSlashes(t *testing.T) {
var seen string
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
seen = r.RequestURI
}))
defer backend.Close()
b, _ := url.Parse(backend.URL)
proxyHandler := NewReverseProxy(b)
setProxyDirector(proxyHandler)
frontend := httptest.NewServer(proxyHandler)
defer frontend.Close()
f, _ := url.Parse(frontend.URL)
2015-03-21 21:29:07 +02:00
encodedPath := "/a%2Fb/?c=1"
getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: f.Host, Opaque: encodedPath}}
_, err := http.DefaultClient.Do(getReq)
if err != nil {
t.Fatalf("err %s", err)
}
2015-03-21 21:29:07 +02:00
if seen != encodedPath {
t.Errorf("got bad request %q expected %q", seen, encodedPath)
}
}
// TestRobotsTxt checks that GET /robots.txt on the proxy returns 200 with a
// body that disallows every user agent.
func TestRobotsTxt(t *testing.T) {
	opts := NewOptions()
	opts.ClientID = "bazquux"
	opts.ClientSecret = "foobar"
	opts.CookieSecret = "xyzzyplugh"
	opts.Upstreams = append(opts.Upstreams, "unused")
	opts.Validate()

	allowAll := func(string) bool { return true }
	proxy := NewOauthProxy(opts, allowAll)

	req, _ := http.NewRequest("GET", "/robots.txt", nil)
	recorder := httptest.NewRecorder()
	proxy.ServeHTTP(recorder, req)

	const wantBody = "User-agent: *\nDisallow: /"
	assert.Equal(t, 200, recorder.Code)
	assert.Equal(t, wantBody, recorder.Body.String())
}
type TestProvider struct {
*providers.ProviderData
EmailAddress string
2015-05-13 03:48:13 +02:00
ValidToken bool
}
func (tp *TestProvider) GetEmailAddress(session *providers.SessionState) (string, error) {
return tp.EmailAddress, nil
}
func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bool {
2015-05-13 03:48:13 +02:00
return tp.ValidToken
}
type (
	// PassAccessTokenTest bundles the fake provider/upstream server, the proxy
	// under test, and the options the proxy was built from.
	PassAccessTokenTest struct {
		provider_server *httptest.Server
		proxy           *OauthProxy
		opts            *Options
	}

	// PassAccessTokenTestOptions parameterizes NewPassAccessTokenTest.
	PassAccessTokenTestOptions struct {
		// PassAccessToken is copied into the proxy's Options.
		PassAccessToken bool
	}
)
func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest {
t := &PassAccessTokenTest{}
t.provider_server = httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2015-05-21 05:23:48 +02:00
log.Printf("%#v", r)
url := r.URL
payload := ""
switch url.Path {
case "/oauth/token":
payload = `{"access_token": "my_auth_token"}`
default:
2015-05-21 05:23:48 +02:00
payload = r.Header.Get("X-Forwarded-Access-Token")
if payload == "" {
payload = "No access token found."
}
}
w.WriteHeader(200)
w.Write([]byte(payload))
}))
t.opts = NewOptions()
t.opts.Upstreams = append(t.opts.Upstreams, t.provider_server.URL)
// The CookieSecret must be 32 bytes in order to create the AES
// cipher.
t.opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp"
t.opts.ClientID = "bazquux"
t.opts.ClientSecret = "foobar"
t.opts.CookieSecure = false
t.opts.PassAccessToken = opts.PassAccessToken
t.opts.Validate()
provider_url, _ := url.Parse(t.provider_server.URL)
const email_address = "michael.bland@gsa.gov"
t.opts.provider = &TestProvider{
ProviderData: &providers.ProviderData{
ProviderName: "Test Provider",
LoginUrl: &url.URL{
Scheme: "http",
Host: provider_url.Host,
Path: "/oauth/authorize",
},
RedeemUrl: &url.URL{
Scheme: "http",
Host: provider_url.Host,
Path: "/oauth/token",
},
ProfileUrl: &url.URL{
Scheme: "http",
Host: provider_url.Host,
Path: "/api/v1/profile",
},
Scope: "profile.email",
},
EmailAddress: email_address,
}
t.proxy = NewOauthProxy(t.opts, func(email string) bool {
return email == email_address
})
return t
}
func (pat_test *PassAccessTokenTest) Close() {
pat_test.provider_server.Close()
}
func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int,
cookie string) {
rw := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code",
strings.NewReader(""))
if err != nil {
return 0, ""
}
pat_test.proxy.ServeHTTP(rw, req)
return rw.Code, rw.HeaderMap["Set-Cookie"][0]
}
2015-05-21 05:23:48 +02:00
func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code int, access_token string) {
cookieName := pat_test.proxy.CookieName
var value string
key_prefix := cookieName + "="
for _, field := range strings.Split(cookie, "; ") {
value = strings.TrimPrefix(field, key_prefix)
if value != field {
break
} else {
value = ""
}
}
if value == "" {
return 0, ""
}
req, err := http.NewRequest("GET", "/", strings.NewReader(""))
if err != nil {
return 0, ""
}
req.AddCookie(&http.Cookie{
Name: cookieName,
Value: value,
Path: "/",
Expires: time.Now().Add(time.Duration(24)),
HttpOnly: true,
})
rw := httptest.NewRecorder()
pat_test.proxy.ServeHTTP(rw, req)
return rw.Code, rw.Body.String()
}
// TestForwardAccessTokenUpstream checks that, with PassAccessToken enabled,
// the access token redeemed at the callback reaches the upstream in the
// X-Forwarded-Access-Token header.
func TestForwardAccessTokenUpstream(t *testing.T) {
	patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{
		PassAccessToken: true,
	})
	defer patTest.Close()

	// A successful validation will redirect and set the auth cookie.
	code, cookie := patTest.getCallbackEndpoint()
	assert.Equal(t, 302, code)
	// cookie is a string and can never equal nil, so check for emptiness.
	assert.NotEqual(t, "", cookie)

	// Now we make a regular request; the access_token from the cookie is
	// forwarded as the "X-Forwarded-Access-Token" header. The token is
	// read by the test provider server and written in the response body.
	code, payload := patTest.getRootEndpoint(cookie)
	assert.Equal(t, 200, code)
	assert.Equal(t, "my_auth_token", payload)
}
// TestDoNotForwardAccessTokenUpstream checks that, with PassAccessToken
// disabled, the upstream receives no X-Forwarded-Access-Token header.
func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
	patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{
		PassAccessToken: false,
	})
	defer patTest.Close()

	// A successful validation will redirect and set the auth cookie.
	code, cookie := patTest.getCallbackEndpoint()
	assert.Equal(t, 302, code)
	// cookie is a string and can never equal nil, so check for emptiness.
	assert.NotEqual(t, "", cookie)

	// Now we make a regular request, but the access token header should
	// not be present.
	code, payload := patTest.getRootEndpoint(cookie)
	assert.Equal(t, 200, code)
	assert.Equal(t, "No access token found.", payload)
}
// signInRedirectPattern matches the hidden "rd" input on the rendered
// sign-in page. Its single capture group is the input's value, which the
// tests treat as the post-sign-in redirect target.
const signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">`

// SignInPageTest holds a proxy set up for sign-in page tests, the options it
// was built from, and a compiled signInRedirectPattern.
type SignInPageTest struct {
	opts           *Options
	proxy          *OauthProxy
	sign_in_regexp *regexp.Regexp
}
// NewSignInPageTest returns a SignInPageTest whose proxy accepts every email
// address and points at a placeholder upstream ("unused").
func NewSignInPageTest() *SignInPageTest {
	opts := NewOptions()
	opts.Upstreams = append(opts.Upstreams, "unused")
	opts.CookieSecret = "foobar"
	opts.ClientID = "bazquux"
	opts.ClientSecret = "xyzzyplugh"
	opts.Validate()

	acceptAll := func(email string) bool { return true }
	return &SignInPageTest{
		opts:           opts,
		proxy:          NewOauthProxy(opts, acceptAll),
		sign_in_regexp: regexp.MustCompile(signInRedirectPattern),
	}
}
// GetEndpoint sends a GET for endpoint through the proxy and returns the
// recorded status code and body.
func (sip *SignInPageTest) GetEndpoint(endpoint string) (int, string) {
	// NOTE(review): the NewRequest error is discarded; a malformed endpoint
	// would leave req nil.
	req, _ := http.NewRequest("GET", endpoint, strings.NewReader(""))
	recorder := httptest.NewRecorder()
	sip.proxy.ServeHTTP(recorder, req)
	return recorder.Code, recorder.Body.String()
}
// TestSignInPageIncludesTargetRedirect requests an arbitrary path without a
// session. It expects a 403 sign-in page whose redirect input carries that
// same path.
func TestSignInPageIncludesTargetRedirect(t *testing.T) {
	const endpoint = "/some/random/endpoint"
	sip := NewSignInPageTest()

	code, body := sip.GetEndpoint(endpoint)
	assert.Equal(t, 403, code)

	match := sip.sign_in_regexp.FindStringSubmatch(body)
	switch {
	case match == nil:
		t.Fatal("Did not find pattern in body: " +
			signInRedirectPattern + "\nBody:\n" + body)
	case match[1] != endpoint:
		t.Fatal(`expected redirect to "` + endpoint +
			`", but was "` + match[1] + `"`)
	}
}
// TestSignInPageDirectAccessRedirectsToRoot requests the sign-in page
// directly. It expects a 200 whose redirect input points at "/".
func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
	sip := NewSignInPageTest()

	code, body := sip.GetEndpoint("/oauth2/sign_in")
	assert.Equal(t, 200, code)

	match := sip.sign_in_regexp.FindStringSubmatch(body)
	switch {
	case match == nil:
		t.Fatal("Did not find pattern in body: " +
			signInRedirectPattern + "\nBody:\n" + body)
	case match[1] != "/":
		t.Fatal(`expected redirect to "/", but was "` + match[1] + `"`)
	}
}
2015-05-08 17:52:03 +02:00
type ProcessCookieTest struct {
opts *Options
proxy *OauthProxy
rw *httptest.ResponseRecorder
req *http.Request
2015-05-13 03:48:13 +02:00
provider TestProvider
response_code int
2015-05-09 22:48:39 +02:00
validate_user bool
2015-05-08 17:52:03 +02:00
}
// ProcessCookieTestOpts parameterizes NewProcessCookieTest.
type ProcessCookieTestOpts struct {
	// provider_validate_cookie_response becomes the stub provider's
	// ValidToken, i.e. the result of its ValidateSessionState.
	provider_validate_cookie_response bool
}
func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest {
2015-05-08 17:52:03 +02:00
var pc_test ProcessCookieTest
pc_test.opts = NewOptions()
pc_test.opts.Upstreams = append(pc_test.opts.Upstreams, "unused")
pc_test.opts.ClientID = "bazquux"
pc_test.opts.ClientSecret = "xyzzyplugh"
pc_test.opts.CookieSecret = "0123456789abcdef"
// First, set the CookieRefresh option so proxy.AesCipher is created,
// needed to encrypt the access_token.
pc_test.opts.CookieRefresh = time.Hour
2015-05-08 17:52:03 +02:00
pc_test.opts.Validate()
pc_test.proxy = NewOauthProxy(pc_test.opts, func(email string) bool {
2015-05-09 22:48:39 +02:00
return pc_test.validate_user
2015-05-08 17:52:03 +02:00
})
2015-05-13 03:48:13 +02:00
pc_test.proxy.provider = &TestProvider{
ValidToken: opts.provider_validate_cookie_response,
}
2015-05-08 17:52:03 +02:00
// Now, zero-out proxy.CookieRefresh for the cases that don't involve
// access_token validation.
pc_test.proxy.CookieRefresh = time.Duration(0)
2015-05-08 17:52:03 +02:00
pc_test.rw = httptest.NewRecorder()
pc_test.req, _ = http.NewRequest("GET", "/", strings.NewReader(""))
2015-05-09 22:48:39 +02:00
pc_test.validate_user = true
2015-05-08 17:52:03 +02:00
return &pc_test
}
2015-05-13 03:48:13 +02:00
func NewProcessCookieTestWithDefaults() *ProcessCookieTest {
return NewProcessCookieTest(ProcessCookieTestOpts{
provider_validate_cookie_response: true,
})
}
func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) *http.Cookie {
return p.proxy.MakeCookie(p.req, value, p.opts.CookieExpire, ref)
2015-05-08 17:52:03 +02:00
}
func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error {
value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher)
if err != nil {
return err
}
p.req.AddCookie(p.proxy.MakeCookie(p.req, value, p.proxy.CookieExpire, ref))
return nil
2015-05-08 17:52:03 +02:00
}
func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time.Duration, error) {
return p.proxy.LoadCookiedSession(p.req)
2015-05-08 17:52:03 +02:00
}
func TestLoadCookiedSession(t *testing.T) {
2015-05-13 03:48:13 +02:00
pc_test := NewProcessCookieTestWithDefaults()
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pc_test.SaveSession(startSession, time.Now())
session, _, err := pc_test.LoadCookiedSession()
assert.Equal(t, nil, err)
assert.Equal(t, startSession.Email, session.Email)
assert.Equal(t, "michael.bland", session.User)
assert.Equal(t, startSession.AccessToken, session.AccessToken)
2015-05-08 17:52:03 +02:00
}
func TestProcessCookieNoCookieError(t *testing.T) {
2015-05-13 03:48:13 +02:00
pc_test := NewProcessCookieTestWithDefaults()
session, _, err := pc_test.LoadCookiedSession()
assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error())
if session != nil {
t.Errorf("expected nil session. got %#v", session)
}
}
func TestProcessCookieRefreshNotSet(t *testing.T) {
2015-05-13 03:48:13 +02:00
pc_test := NewProcessCookieTestWithDefaults()
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
reference := time.Now().Add(time.Duration(-2) * time.Hour)
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pc_test.SaveSession(startSession, reference)
session, age, err := pc_test.LoadCookiedSession()
assert.Equal(t, nil, err)
if age < time.Duration(-2)*time.Hour {
t.Errorf("cookie too young %v", age)
}
assert.Equal(t, startSession.Email, session.Email)
}
// TestProcessCookieFailIfCookieExpired stores a session cookie back-dated by
// 25 hours against a 24-hour CookieExpire. It expects loading to fail and
// return no session.
func TestProcessCookieFailIfCookieExpired(t *testing.T) {
	pcTest := NewProcessCookieTestWithDefaults()
	pcTest.proxy.CookieExpire = 24 * time.Hour
	reference := time.Now().Add(-25 * time.Hour)
	startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
	// If the cookie were never attached, the load below would still fail,
	// but with "not present", and the test would pass for the wrong reason.
	if err := pcTest.SaveSession(startSession, reference); err != nil {
		t.Fatalf("SaveSession: %s", err)
	}

	session, _, err := pcTest.LoadCookiedSession()
	assert.NotEqual(t, nil, err)
	if session != nil {
		t.Errorf("expected nil session %#v", session)
	}
}
// TestProcessCookieFailIfRefreshSetAndCookieExpired is like
// TestProcessCookieFailIfCookieExpired, but enables a one-hour CookieRefresh
// before loading. An expired cookie must still be rejected, with no session
// returned.
func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
	pcTest := NewProcessCookieTestWithDefaults()
	pcTest.proxy.CookieExpire = 24 * time.Hour
	reference := time.Now().Add(-25 * time.Hour)
	startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
	// If the cookie were never attached, the load below would still fail,
	// but with "not present", and the test would pass for the wrong reason.
	if err := pcTest.SaveSession(startSession, reference); err != nil {
		t.Fatalf("SaveSession: %s", err)
	}
	pcTest.proxy.CookieRefresh = time.Hour

	session, _, err := pcTest.LoadCookiedSession()
	assert.NotEqual(t, nil, err)
	if session != nil {
		t.Errorf("expected nil session %#v", session)
	}
}