mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-03-25 22:00:56 +02:00
* Encode sessions with MsgPack + LZ4 Assumes ciphers are now mandatory per #414. Cookie & Redis sessions can fallback to V5 style JSON in error cases. TODO: session_state.go unit tests & new unit tests for Legacy fallback scenarios. * Only compress encoded sessions with Cookie Store * Cleanup msgpack + lz4 error handling * Change NewBase64Cipher to take in an existing Cipher * Add msgpack & lz4 session state tests * Add required options for oauthproxy tests More aggressively assert.NoError on all validation.Validate(opts) calls to enforce legal options in all our tests. Add additional NoError checks wherever error return values were ignored. * Remove support for uncompressed session state fields * Improve error verbosity & add session state tests * Ensure all marshalled sessions are valid Invalid CFB decryptions can result in garbage data that 1/100 times might cause message pack unmarshal to not fail and instead return an empty session. This adds more rigor to make sure legacy sessions cause appropriate errors. * Add tests for legacy V5 session decoding Refactor common legacy JSON test cases to a legacy helpers area under session store tests. * Make ValidateSession a struct method & add CHANGELOG entry * Improve SessionState error & comments verbosity * Move legacy session test helpers to sessions pkg Placing these helpers under the sessions pkg removed all the circular import uses in housing it under the session store area. * Improve SignatureAuthenticator test helper formatting * Make redis.legacyV5DecodeSession internal * Make LegacyV5TestCase test table public for linter
2221 lines
60 KiB
Go
2221 lines
60 KiB
Go
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"crypto"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"os"
|
|
"regexp"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/coreos/go-oidc"
|
|
"github.com/mbland/hmacauth"
|
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
|
sessionscookie "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/cookie"
|
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/validation"
|
|
"github.com/oauth2-proxy/oauth2-proxy/providers"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/net/websocket"
|
|
)
|
|
|
|
// Fixed credentials shared by the tests in this file.
const (
	// rawCookieSecret is a 32-byte cookie secret. base64CookieSecret is the
	// same secret, base64 encoded without padding.
	rawCookieSecret    = "secretthirtytwobytes+abcdefghijk"
	base64CookieSecret = "c2VjcmV0dGhpcnR5dHdvYnl0ZXMrYWJjZGVmZ2hpams"

	// clientID and clientSecret are fixed dummy values used by the tests.
	clientID     = "3984n253984d7348dm8234yf982t"
	clientSecret = "gv3498mfc9t23y23974dm2394dm9"
)
|
|
|
|
func init() {
|
|
logger.SetFlags(logger.Lshortfile)
|
|
}
|
|
|
|
// WebSocketOrRestHandler is a test backend that serves websocket upgrade
// requests and ordinary HTTP requests with two separate handlers.
type WebSocketOrRestHandler struct {
	// restHandler serves every request that is not a websocket upgrade.
	restHandler http.Handler
	// wsHandler serves requests whose Upgrade header is "websocket".
	wsHandler http.Handler
}
|
|
|
|
func (h *WebSocketOrRestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
if r.Header.Get("Upgrade") == "websocket" {
|
|
h.wsHandler.ServeHTTP(w, r)
|
|
} else {
|
|
h.restHandler.ServeHTTP(w, r)
|
|
}
|
|
}
|
|
|
|
func TestWebSocketProxy(t *testing.T) {
|
|
handler := WebSocketOrRestHandler{
|
|
restHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(200)
|
|
hostname, _, _ := net.SplitHostPort(r.Host)
|
|
_, err := w.Write([]byte(hostname))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}),
|
|
wsHandler: websocket.Handler(func(ws *websocket.Conn) {
|
|
defer func(t *testing.T) {
|
|
if err := ws.Close(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}(t)
|
|
var data []byte
|
|
err := websocket.Message.Receive(ws, &data)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
err = websocket.Message.Send(ws, data)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}),
|
|
}
|
|
backend := httptest.NewServer(&handler)
|
|
t.Cleanup(backend.Close)
|
|
|
|
backendURL, _ := url.Parse(backend.URL)
|
|
|
|
opts := baseTestOptions()
|
|
var auth hmacauth.HmacAuth
|
|
opts.PassHostHeader = true
|
|
proxyHandler := NewWebSocketOrRestReverseProxy(backendURL, opts, auth)
|
|
frontend := httptest.NewServer(proxyHandler)
|
|
t.Cleanup(frontend.Close)
|
|
|
|
frontendURL, _ := url.Parse(frontend.URL)
|
|
frontendWSURL := "ws://" + frontendURL.Host + "/"
|
|
|
|
ws, err := websocket.Dial(frontendWSURL, "", "http://localhost/")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
request := []byte("hello, world!")
|
|
err = websocket.Message.Send(ws, request)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
var response = make([]byte, 1024)
|
|
err = websocket.Message.Receive(ws, &response)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if g, e := string(request), string(response); g != e {
|
|
t.Errorf("got body %q; expected %q", g, e)
|
|
}
|
|
|
|
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
|
|
res, _ := http.DefaultClient.Do(getReq)
|
|
bodyBytes, _ := ioutil.ReadAll(res.Body)
|
|
backendHostname, _, _ := net.SplitHostPort(backendURL.Host)
|
|
if g, e := string(bodyBytes), backendHostname; g != e {
|
|
t.Errorf("got body %q; expected %q", g, e)
|
|
}
|
|
}
|
|
|
|
func TestNewReverseProxy(t *testing.T) {
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(200)
|
|
hostname, _, _ := net.SplitHostPort(r.Host)
|
|
_, err := w.Write([]byte(hostname))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}))
|
|
t.Cleanup(backend.Close)
|
|
|
|
backendURL, _ := url.Parse(backend.URL)
|
|
backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host)
|
|
backendHost := net.JoinHostPort(backendHostname, backendPort)
|
|
proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/")
|
|
|
|
proxyHandler := NewReverseProxy(proxyURL, &options.Options{FlushInterval: time.Second})
|
|
setProxyUpstreamHostHeader(proxyHandler, proxyURL)
|
|
frontend := httptest.NewServer(proxyHandler)
|
|
t.Cleanup(frontend.Close)
|
|
|
|
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
|
|
res, _ := http.DefaultClient.Do(getReq)
|
|
bodyBytes, _ := ioutil.ReadAll(res.Body)
|
|
if g, e := string(bodyBytes), backendHostname; g != e {
|
|
t.Errorf("got body %q; expected %q", g, e)
|
|
}
|
|
}
|
|
|
|
func TestEncodedSlashes(t *testing.T) {
|
|
var seen string
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(200)
|
|
seen = r.RequestURI
|
|
}))
|
|
t.Cleanup(backend.Close)
|
|
|
|
b, _ := url.Parse(backend.URL)
|
|
proxyHandler := NewReverseProxy(b, &options.Options{FlushInterval: time.Second})
|
|
setProxyDirector(proxyHandler)
|
|
frontend := httptest.NewServer(proxyHandler)
|
|
t.Cleanup(frontend.Close)
|
|
|
|
f, _ := url.Parse(frontend.URL)
|
|
encodedPath := "/a%2Fb/?c=1"
|
|
getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: f.Host, Opaque: encodedPath}}
|
|
_, err := http.DefaultClient.Do(getReq)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if seen != encodedPath {
|
|
t.Errorf("got bad request %q expected %q", seen, encodedPath)
|
|
}
|
|
}
|
|
|
|
func TestRobotsTxt(t *testing.T) {
|
|
opts := baseTestOptions()
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
|
|
proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
rw := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", "/robots.txt", nil)
|
|
proxy.ServeHTTP(rw, req)
|
|
assert.Equal(t, 200, rw.Code)
|
|
assert.Equal(t, "User-agent: *\nDisallow: /", rw.Body.String())
|
|
}
|
|
|
|
func TestIsValidRedirect(t *testing.T) {
|
|
opts := baseTestOptions()
|
|
// Should match domains that are exactly foo.bar and any subdomain of bar.foo
|
|
opts.WhitelistDomains = []string{
|
|
"foo.bar",
|
|
".bar.foo",
|
|
"port.bar:8080",
|
|
".sub.port.bar:8080",
|
|
"anyport.bar:*",
|
|
".sub.anyport.bar:*",
|
|
}
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
|
|
proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
testCases := []struct {
|
|
Desc, Redirect string
|
|
ExpectedResult bool
|
|
}{
|
|
{
|
|
Desc: "noRD",
|
|
Redirect: "",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "singleSlash",
|
|
Redirect: "/redirect",
|
|
ExpectedResult: true,
|
|
},
|
|
{
|
|
Desc: "doubleSlash",
|
|
Redirect: "//redirect",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "validHTTP",
|
|
Redirect: "http://foo.bar/redirect",
|
|
ExpectedResult: true,
|
|
},
|
|
{
|
|
Desc: "validHTTPS",
|
|
Redirect: "https://foo.bar/redirect",
|
|
ExpectedResult: true,
|
|
},
|
|
{
|
|
Desc: "invalidHTTPSubdomain",
|
|
Redirect: "http://baz.foo.bar/redirect",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "invalidHTTPSSubdomain",
|
|
Redirect: "https://baz.foo.bar/redirect",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "validHTTPSubdomain",
|
|
Redirect: "http://baz.bar.foo/redirect",
|
|
ExpectedResult: true,
|
|
},
|
|
{
|
|
Desc: "validHTTPSSubdomain",
|
|
Redirect: "https://baz.bar.foo/redirect",
|
|
ExpectedResult: true,
|
|
},
|
|
{
|
|
Desc: "validHTTPDomain",
|
|
Redirect: "http://bar.foo/redirect",
|
|
ExpectedResult: true,
|
|
},
|
|
{
|
|
Desc: "invalidHTTP1",
|
|
Redirect: "http://foo.bar.evil.corp/redirect",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "invalidHTTPS1",
|
|
Redirect: "https://foo.bar.evil.corp/redirect",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "invalidHTTP2",
|
|
Redirect: "http://evil.corp/redirect?rd=foo.bar",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "invalidHTTPS2",
|
|
Redirect: "https://evil.corp/redirect?rd=foo.bar",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "invalidPort",
|
|
Redirect: "https://evil.corp:3838/redirect",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "invalidEmptyPort",
|
|
Redirect: "http://foo.bar:3838/redirect",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "invalidEmptyPortSubdomain",
|
|
Redirect: "http://baz.bar.foo:3838/redirect",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "validSpecificPort",
|
|
Redirect: "http://port.bar:8080/redirect",
|
|
ExpectedResult: true,
|
|
},
|
|
{
|
|
Desc: "invalidSpecificPort",
|
|
Redirect: "http://port.bar:3838/redirect",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "validSpecificPortSubdomain",
|
|
Redirect: "http://foo.sub.port.bar:8080/redirect",
|
|
ExpectedResult: true,
|
|
},
|
|
{
|
|
Desc: "invalidSpecificPortSubdomain",
|
|
Redirect: "http://foo.sub.port.bar:3838/redirect",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "validAnyPort1",
|
|
Redirect: "http://anyport.bar:8080/redirect",
|
|
ExpectedResult: true,
|
|
},
|
|
{
|
|
Desc: "validAnyPort2",
|
|
Redirect: "http://anyport.bar:8081/redirect",
|
|
ExpectedResult: true,
|
|
},
|
|
{
|
|
Desc: "validAnyPortSubdomain1",
|
|
Redirect: "http://a.sub.anyport.bar:8080/redirect",
|
|
ExpectedResult: true,
|
|
},
|
|
{
|
|
Desc: "validAnyPortSubdomain2",
|
|
Redirect: "http://a.sub.anyport.bar:8081/redirect",
|
|
ExpectedResult: true,
|
|
},
|
|
{
|
|
Desc: "openRedirect1",
|
|
Redirect: "/\\evil.com",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "openRedirectSpace1",
|
|
Redirect: "/ /evil.com",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "openRedirectSpace2",
|
|
Redirect: "/ \\evil.com",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "openRedirectTab1",
|
|
Redirect: "/\t/evil.com",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "openRedirectTab2",
|
|
Redirect: "/\t\\evil.com",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "openRedirectVerticalTab1",
|
|
Redirect: "/\v/evil.com",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "openRedirectVerticalTab2",
|
|
Redirect: "/\v\\evil.com",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "openRedirectNewLine1",
|
|
Redirect: "/\n/evil.com",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "openRedirectNewLine2",
|
|
Redirect: "/\n\\evil.com",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "openRedirectCarriageReturn1",
|
|
Redirect: "/\r/evil.com",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "openRedirectCarriageReturn2",
|
|
Redirect: "/\r\\evil.com",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "openRedirectTripleTab",
|
|
Redirect: "/\t\t/\t/evil.com",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "openRedirectTripleTab2",
|
|
Redirect: "/\t\t\\\t/evil.com",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "openRedirectQuadTab1",
|
|
Redirect: "/\t\t/\t\t\\evil.com",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "openRedirectQuadTab2",
|
|
Redirect: "/\t\t\\\t\t/evil.com",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "openRedirectPeriod1",
|
|
Redirect: "/./\\evil.com",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "openRedirectPeriod2",
|
|
Redirect: "/./../../\\evil.com",
|
|
ExpectedResult: false,
|
|
},
|
|
{
|
|
Desc: "openRedirectDoubleTab",
|
|
Redirect: "/\t/\t\\evil.com",
|
|
ExpectedResult: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.Desc, func(t *testing.T) {
|
|
result := proxy.IsValidRedirect(tc.Redirect)
|
|
|
|
if result != tc.ExpectedResult {
|
|
t.Errorf("expected %t got %t", tc.ExpectedResult, result)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOpenRedirects(t *testing.T) {
|
|
opts := baseTestOptions()
|
|
// Should match domains that are exactly foo.bar and any subdomain of bar.foo
|
|
opts.WhitelistDomains = []string{
|
|
"foo.bar",
|
|
".bar.foo",
|
|
"port.bar:8080",
|
|
".sub.port.bar:8080",
|
|
"anyport.bar:*",
|
|
".sub.anyport.bar:*",
|
|
"www.whitelisteddomain.tld",
|
|
}
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
|
|
proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
file, err := os.Open("./test/openredirects.txt")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer func(t *testing.T) {
|
|
if err := file.Close(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}(t)
|
|
|
|
scanner := bufio.NewScanner(file)
|
|
for scanner.Scan() {
|
|
rd := scanner.Text()
|
|
t.Run(rd, func(t *testing.T) {
|
|
rdUnescaped, err := url.QueryUnescape(rd)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if proxy.IsValidRedirect(rdUnescaped) {
|
|
t.Errorf("Expected %q to not be valid (unescaped: %q)", rd, rdUnescaped)
|
|
}
|
|
})
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
type TestProvider struct {
|
|
*providers.ProviderData
|
|
EmailAddress string
|
|
ValidToken bool
|
|
GroupValidator func(string) bool
|
|
}
|
|
|
|
var _ providers.Provider = (*TestProvider)(nil)
|
|
|
|
func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
|
|
return &TestProvider{
|
|
ProviderData: &providers.ProviderData{
|
|
ProviderName: "Test Provider",
|
|
LoginURL: &url.URL{
|
|
Scheme: "http",
|
|
Host: providerURL.Host,
|
|
Path: "/oauth/authorize",
|
|
},
|
|
RedeemURL: &url.URL{
|
|
Scheme: "http",
|
|
Host: providerURL.Host,
|
|
Path: "/oauth/token",
|
|
},
|
|
ProfileURL: &url.URL{
|
|
Scheme: "http",
|
|
Host: providerURL.Host,
|
|
Path: "/api/v1/profile",
|
|
},
|
|
Scope: "profile.email",
|
|
},
|
|
EmailAddress: emailAddress,
|
|
GroupValidator: func(s string) bool {
|
|
return true
|
|
},
|
|
}
|
|
}
|
|
|
|
func (tp *TestProvider) GetEmailAddress(ctx context.Context, session *sessions.SessionState) (string, error) {
|
|
return tp.EmailAddress, nil
|
|
}
|
|
|
|
func (tp *TestProvider) ValidateSessionState(ctx context.Context, session *sessions.SessionState) bool {
|
|
return tp.ValidToken
|
|
}
|
|
|
|
func TestBasicAuthPassword(t *testing.T) {
|
|
providerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
logger.Printf("%#v", r)
|
|
var payload string
|
|
switch r.URL.Path {
|
|
case "/oauth/token":
|
|
payload = `{"access_token": "my_auth_token"}`
|
|
default:
|
|
payload = r.Header.Get("Authorization")
|
|
if payload == "" {
|
|
payload = "No Authorization header found."
|
|
}
|
|
}
|
|
w.WriteHeader(200)
|
|
_, err := w.Write([]byte(payload))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}))
|
|
opts := baseTestOptions()
|
|
opts.Upstreams = append(opts.Upstreams, providerServer.URL)
|
|
opts.Cookie.Secure = false
|
|
opts.PassBasicAuth = true
|
|
opts.SetBasicAuth = true
|
|
opts.PassUserHeaders = true
|
|
opts.PreferEmailToUser = true
|
|
opts.BasicAuthPassword = "This is a secure password"
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
|
|
providerURL, _ := url.Parse(providerServer.URL)
|
|
const emailAddress = "john.doe@example.com"
|
|
|
|
opts.SetProvider(NewTestProvider(providerURL, emailAddress))
|
|
proxy, err := NewOAuthProxy(opts, func(email string) bool {
|
|
return email == emailAddress
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
rw := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:", strings.NewReader(""))
|
|
req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now()))
|
|
proxy.ServeHTTP(rw, req)
|
|
if rw.Code >= 400 {
|
|
t.Fatalf("expected 3xx got %d", rw.Code)
|
|
}
|
|
cookie := rw.Header().Values("Set-Cookie")[1]
|
|
|
|
cookieName := proxy.CookieName
|
|
var value string
|
|
keyPrefix := cookieName + "="
|
|
|
|
for _, field := range strings.Split(cookie, "; ") {
|
|
value = strings.TrimPrefix(field, keyPrefix)
|
|
if value != field {
|
|
break
|
|
} else {
|
|
value = ""
|
|
}
|
|
}
|
|
|
|
req, _ = http.NewRequest("GET", "/", strings.NewReader(""))
|
|
req.AddCookie(&http.Cookie{
|
|
Name: cookieName,
|
|
Value: value,
|
|
Path: "/",
|
|
Expires: time.Now().Add(time.Duration(24)),
|
|
HttpOnly: true,
|
|
})
|
|
req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now()))
|
|
|
|
rw = httptest.NewRecorder()
|
|
proxy.ServeHTTP(rw, req)
|
|
|
|
// The username in the basic auth credentials is expected to be equal to the email address from the
|
|
// auth response, so we use the same variable here.
|
|
expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(emailAddress+":"+opts.BasicAuthPassword))
|
|
assert.Equal(t, expectedHeader, rw.Body.String())
|
|
providerServer.Close()
|
|
}
|
|
|
|
func TestBasicAuthWithEmail(t *testing.T) {
|
|
opts := baseTestOptions()
|
|
opts.PassBasicAuth = true
|
|
opts.PassUserHeaders = false
|
|
opts.PreferEmailToUser = false
|
|
opts.BasicAuthPassword = "This is a secure password"
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
|
|
const emailAddress = "john.doe@example.com"
|
|
const userName = "9fcab5c9b889a557"
|
|
|
|
// The username in the basic auth credentials is expected to be equal to the email address from the
|
|
expectedEmailHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(emailAddress+":"+opts.BasicAuthPassword))
|
|
expectedUserHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(userName+":"+opts.BasicAuthPassword))
|
|
|
|
created := time.Now()
|
|
session := &sessions.SessionState{
|
|
User: userName,
|
|
Email: emailAddress,
|
|
AccessToken: "oauth_token",
|
|
CreatedAt: &created,
|
|
}
|
|
{
|
|
rw := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", opts.ProxyPrefix+"/testCase0", nil)
|
|
proxy, err := NewOAuthProxy(opts, func(email string) bool {
|
|
return email == emailAddress
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
proxy.addHeadersForProxying(rw, req, session)
|
|
assert.Equal(t, expectedUserHeader, req.Header["Authorization"][0])
|
|
assert.Equal(t, userName, req.Header["X-Forwarded-User"][0])
|
|
}
|
|
|
|
opts.PreferEmailToUser = true
|
|
{
|
|
rw := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", opts.ProxyPrefix+"/testCase1", nil)
|
|
|
|
proxy, err := NewOAuthProxy(opts, func(email string) bool {
|
|
return email == emailAddress
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
proxy.addHeadersForProxying(rw, req, session)
|
|
assert.Equal(t, expectedEmailHeader, req.Header["Authorization"][0])
|
|
assert.Equal(t, emailAddress, req.Header["X-Forwarded-User"][0])
|
|
}
|
|
}
|
|
|
|
func TestPassUserHeadersWithEmail(t *testing.T) {
|
|
opts := baseTestOptions()
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
|
|
const emailAddress = "john.doe@example.com"
|
|
const userName = "9fcab5c9b889a557"
|
|
|
|
created := time.Now()
|
|
session := &sessions.SessionState{
|
|
User: userName,
|
|
Email: emailAddress,
|
|
AccessToken: "oauth_token",
|
|
CreatedAt: &created,
|
|
}
|
|
{
|
|
rw := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", opts.ProxyPrefix+"/testCase0", nil)
|
|
proxy, err := NewOAuthProxy(opts, func(email string) bool {
|
|
return email == emailAddress
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
proxy.addHeadersForProxying(rw, req, session)
|
|
assert.Equal(t, userName, req.Header["X-Forwarded-User"][0])
|
|
}
|
|
|
|
opts.PreferEmailToUser = true
|
|
{
|
|
rw := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", opts.ProxyPrefix+"/testCase1", nil)
|
|
|
|
proxy, err := NewOAuthProxy(opts, func(email string) bool {
|
|
return email == emailAddress
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
proxy.addHeadersForProxying(rw, req, session)
|
|
assert.Equal(t, emailAddress, req.Header["X-Forwarded-User"][0])
|
|
}
|
|
}
|
|
|
|
type PassAccessTokenTest struct {
|
|
providerServer *httptest.Server
|
|
proxy *OAuthProxy
|
|
opts *options.Options
|
|
}
|
|
|
|
// PassAccessTokenTestOptions configures NewPassAccessTokenTest.
type PassAccessTokenTestOptions struct {
	// PassAccessToken is copied to the proxy option of the same name.
	PassAccessToken bool
	// ProxyUpstream, when non-empty, is appended as an additional upstream.
	ProxyUpstream string
}
|
|
|
|
func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTest, error) {
|
|
patt := &PassAccessTokenTest{}
|
|
|
|
patt.providerServer = httptest.NewServer(
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var payload string
|
|
switch r.URL.Path {
|
|
case "/oauth/token":
|
|
payload = `{"access_token": "my_auth_token"}`
|
|
default:
|
|
payload = r.Header.Get("X-Forwarded-Access-Token")
|
|
if payload == "" {
|
|
payload = "No access token found."
|
|
}
|
|
}
|
|
w.WriteHeader(200)
|
|
_, err := w.Write([]byte(payload))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}))
|
|
|
|
patt.opts = baseTestOptions()
|
|
patt.opts.Upstreams = append(patt.opts.Upstreams, patt.providerServer.URL)
|
|
if opts.ProxyUpstream != "" {
|
|
patt.opts.Upstreams = append(patt.opts.Upstreams, opts.ProxyUpstream)
|
|
}
|
|
patt.opts.Cookie.Secure = false
|
|
patt.opts.PassAccessToken = opts.PassAccessToken
|
|
err := validation.Validate(patt.opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
providerURL, _ := url.Parse(patt.providerServer.URL)
|
|
const emailAddress = "michael.bland@gsa.gov"
|
|
|
|
patt.opts.SetProvider(NewTestProvider(providerURL, emailAddress))
|
|
patt.proxy, err = NewOAuthProxy(patt.opts, func(email string) bool {
|
|
return email == emailAddress
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return patt, nil
|
|
}
|
|
|
|
func (patTest *PassAccessTokenTest) Close() {
|
|
patTest.providerServer.Close()
|
|
}
|
|
|
|
func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int,
|
|
cookie string) {
|
|
rw := httptest.NewRecorder()
|
|
req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
|
|
strings.NewReader(""))
|
|
if err != nil {
|
|
return 0, ""
|
|
}
|
|
req.AddCookie(patTest.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now()))
|
|
patTest.proxy.ServeHTTP(rw, req)
|
|
return rw.Code, rw.Header().Values("Set-Cookie")[1]
|
|
}
|
|
|
|
// getEndpointWithCookie makes a requests againt the oauthproxy with passed requestPath
|
|
// and cookie and returns body and status code.
|
|
func (patTest *PassAccessTokenTest) getEndpointWithCookie(cookie string, endpoint string) (httpCode int, accessToken string) {
|
|
cookieName := patTest.proxy.CookieName
|
|
var value string
|
|
keyPrefix := cookieName + "="
|
|
|
|
for _, field := range strings.Split(cookie, "; ") {
|
|
value = strings.TrimPrefix(field, keyPrefix)
|
|
if value != field {
|
|
break
|
|
} else {
|
|
value = ""
|
|
}
|
|
}
|
|
if value == "" {
|
|
return 0, ""
|
|
}
|
|
|
|
req, err := http.NewRequest("GET", endpoint, strings.NewReader(""))
|
|
if err != nil {
|
|
return 0, ""
|
|
}
|
|
req.AddCookie(&http.Cookie{
|
|
Name: cookieName,
|
|
Value: value,
|
|
Path: "/",
|
|
Expires: time.Now().Add(time.Duration(24)),
|
|
HttpOnly: true,
|
|
})
|
|
|
|
rw := httptest.NewRecorder()
|
|
patTest.proxy.ServeHTTP(rw, req)
|
|
return rw.Code, rw.Body.String()
|
|
}
|
|
|
|
func TestForwardAccessTokenUpstream(t *testing.T) {
|
|
patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
|
PassAccessToken: true,
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Cleanup(patTest.Close)
|
|
|
|
// A successful validation will redirect and set the auth cookie.
|
|
code, cookie := patTest.getCallbackEndpoint()
|
|
if code != 302 {
|
|
t.Fatalf("expected 302; got %d", code)
|
|
}
|
|
assert.NotNil(t, cookie)
|
|
|
|
// Now we make a regular request; the access_token from the cookie is
|
|
// forwarded as the "X-Forwarded-Access-Token" header. The token is
|
|
// read by the test provider server and written in the response body.
|
|
code, payload := patTest.getEndpointWithCookie(cookie, "/")
|
|
if code != 200 {
|
|
t.Fatalf("expected 200; got %d", code)
|
|
}
|
|
assert.Equal(t, "my_auth_token", payload)
|
|
}
|
|
|
|
func TestStaticProxyUpstream(t *testing.T) {
|
|
patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
|
PassAccessToken: true,
|
|
ProxyUpstream: "static://200/static-proxy",
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Cleanup(patTest.Close)
|
|
|
|
// A successful validation will redirect and set the auth cookie.
|
|
code, cookie := patTest.getCallbackEndpoint()
|
|
if code != 302 {
|
|
t.Fatalf("expected 302; got %d", code)
|
|
}
|
|
assert.NotEqual(t, nil, cookie)
|
|
|
|
// Now we make a regular request against the upstream proxy; And validate
|
|
// the returned status code through the static proxy.
|
|
code, payload := patTest.getEndpointWithCookie(cookie, "/static-proxy")
|
|
if code != 200 {
|
|
t.Fatalf("expected 200; got %d", code)
|
|
}
|
|
assert.Equal(t, "Authenticated", payload)
|
|
}
|
|
|
|
func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
|
|
patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
|
PassAccessToken: false,
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Cleanup(patTest.Close)
|
|
|
|
// A successful validation will redirect and set the auth cookie.
|
|
code, cookie := patTest.getCallbackEndpoint()
|
|
if code != 302 {
|
|
t.Fatalf("expected 302; got %d", code)
|
|
}
|
|
assert.NotEqual(t, nil, cookie)
|
|
|
|
// Now we make a regular request, but the access token header should
|
|
// not be present.
|
|
code, payload := patTest.getEndpointWithCookie(cookie, "/")
|
|
if code != 200 {
|
|
t.Fatalf("expected 200; got %d", code)
|
|
}
|
|
assert.Equal(t, "No access token found.", payload)
|
|
}
|
|
|
|
type SignInPageTest struct {
|
|
opts *options.Options
|
|
proxy *OAuthProxy
|
|
signInRegexp *regexp.Regexp
|
|
signInProviderRegexp *regexp.Regexp
|
|
}
|
|
|
|
// Patterns matched against sign-in responses.
const (
	// signInRedirectPattern captures the value of the hidden "rd" form input.
	signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">`
	// signInSkipProvider matches the literal text ">Found<".
	signInSkipProvider = `>Found<`
)
|
|
|
|
func NewSignInPageTest(skipProvider bool) (*SignInPageTest, error) {
|
|
var sipTest SignInPageTest
|
|
|
|
sipTest.opts = baseTestOptions()
|
|
sipTest.opts.SkipProviderButton = skipProvider
|
|
err := validation.Validate(sipTest.opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sipTest.proxy, err = NewOAuthProxy(sipTest.opts, func(email string) bool {
|
|
return true
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
sipTest.signInRegexp = regexp.MustCompile(signInRedirectPattern)
|
|
sipTest.signInProviderRegexp = regexp.MustCompile(signInSkipProvider)
|
|
|
|
return &sipTest, nil
|
|
}
|
|
|
|
func (sipTest *SignInPageTest) GetEndpoint(endpoint string) (int, string) {
|
|
rw := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", endpoint, strings.NewReader(""))
|
|
sipTest.proxy.ServeHTTP(rw, req)
|
|
return rw.Code, rw.Body.String()
|
|
}
|
|
|
|
func TestSignInPageIncludesTargetRedirect(t *testing.T) {
|
|
sipTest, err := NewSignInPageTest(false)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
const endpoint = "/some/random/endpoint"
|
|
|
|
code, body := sipTest.GetEndpoint(endpoint)
|
|
assert.Equal(t, 403, code)
|
|
|
|
match := sipTest.signInRegexp.FindStringSubmatch(body)
|
|
if match == nil {
|
|
t.Fatal("Did not find pattern in body: " +
|
|
signInRedirectPattern + "\nBody:\n" + body)
|
|
}
|
|
if match[1] != endpoint {
|
|
t.Fatal(`expected redirect to "` + endpoint +
|
|
`", but was "` + match[1] + `"`)
|
|
}
|
|
}
|
|
|
|
func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
|
|
sipTest, err := NewSignInPageTest(false)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
code, body := sipTest.GetEndpoint("/oauth2/sign_in")
|
|
assert.Equal(t, 200, code)
|
|
|
|
match := sipTest.signInRegexp.FindStringSubmatch(body)
|
|
if match == nil {
|
|
t.Fatal("Did not find pattern in body: " +
|
|
signInRedirectPattern + "\nBody:\n" + body)
|
|
}
|
|
if match[1] != "/" {
|
|
t.Fatal(`expected redirect to "/", but was "` + match[1] + `"`)
|
|
}
|
|
}
|
|
|
|
func TestSignInPageSkipProvider(t *testing.T) {
|
|
sipTest, err := NewSignInPageTest(true)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
endpoint := "/some/random/endpoint"
|
|
|
|
code, body := sipTest.GetEndpoint(endpoint)
|
|
assert.Equal(t, 302, code)
|
|
|
|
match := sipTest.signInProviderRegexp.FindStringSubmatch(body)
|
|
if match == nil {
|
|
t.Fatal("Did not find pattern in body: " +
|
|
signInSkipProvider + "\nBody:\n" + body)
|
|
}
|
|
}
|
|
|
|
func TestSignInPageSkipProviderDirect(t *testing.T) {
|
|
sipTest, err := NewSignInPageTest(true)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
endpoint := "/sign_in"
|
|
|
|
code, body := sipTest.GetEndpoint(endpoint)
|
|
assert.Equal(t, 302, code)
|
|
|
|
match := sipTest.signInProviderRegexp.FindStringSubmatch(body)
|
|
if match == nil {
|
|
t.Fatal("Did not find pattern in body: " +
|
|
signInSkipProvider + "\nBody:\n" + body)
|
|
}
|
|
}
|
|
|
|
type ProcessCookieTest struct {
|
|
opts *options.Options
|
|
proxy *OAuthProxy
|
|
rw *httptest.ResponseRecorder
|
|
req *http.Request
|
|
validateUser bool
|
|
}
|
|
|
|
// ProcessCookieTestOpts configures NewProcessCookieTest.
type ProcessCookieTestOpts struct {
	// providerValidateCookieResponse becomes the test provider's ValidToken.
	providerValidateCookieResponse bool
}
|
|
|
|
// OptionsModifier adjusts the proxy options in place before they are
// validated by NewProcessCookieTest.
type OptionsModifier func(*options.Options)
|
|
|
|
func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifier) (*ProcessCookieTest, error) {
|
|
var pcTest ProcessCookieTest
|
|
|
|
pcTest.opts = baseTestOptions()
|
|
for _, modifier := range modifiers {
|
|
modifier(pcTest.opts)
|
|
}
|
|
// First, set the CookieRefresh option so proxy.AesCipher is created,
|
|
// needed to encrypt the access_token.
|
|
pcTest.opts.Cookie.Refresh = time.Hour
|
|
err := validation.Validate(pcTest.opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
|
|
return pcTest.validateUser
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
pcTest.proxy.provider = &TestProvider{
|
|
ValidToken: opts.providerValidateCookieResponse,
|
|
}
|
|
|
|
// Now, zero-out proxy.CookieRefresh for the cases that don't involve
|
|
// access_token validation.
|
|
pcTest.proxy.CookieRefresh = time.Duration(0)
|
|
pcTest.rw = httptest.NewRecorder()
|
|
pcTest.req, _ = http.NewRequest("GET", "/", strings.NewReader(""))
|
|
pcTest.validateUser = true
|
|
return &pcTest, nil
|
|
}
|
|
|
|
func NewProcessCookieTestWithDefaults() (*ProcessCookieTest, error) {
|
|
return NewProcessCookieTest(ProcessCookieTestOpts{
|
|
providerValidateCookieResponse: true,
|
|
})
|
|
}
|
|
|
|
func NewProcessCookieTestWithOptionsModifiers(modifiers ...OptionsModifier) (*ProcessCookieTest, error) {
|
|
return NewProcessCookieTest(ProcessCookieTestOpts{
|
|
providerValidateCookieResponse: true,
|
|
}, modifiers...)
|
|
}
|
|
|
|
func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState) error {
|
|
err := p.proxy.SaveSession(p.rw, p.req, s)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, cookie := range p.rw.Result().Cookies() {
|
|
p.req.AddCookie(cookie)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, error) {
|
|
return p.proxy.LoadCookiedSession(p.req)
|
|
}
|
|
|
|
func TestLoadCookiedSession(t *testing.T) {
|
|
pcTest, err := NewProcessCookieTestWithDefaults()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
created := time.Now()
|
|
startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token", CreatedAt: &created}
|
|
err = pcTest.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
session, err := pcTest.LoadCookiedSession()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
assert.Equal(t, startSession.Email, session.Email)
|
|
assert.Equal(t, "", session.User)
|
|
assert.Equal(t, startSession.AccessToken, session.AccessToken)
|
|
}
|
|
|
|
func TestProcessCookieNoCookieError(t *testing.T) {
|
|
pcTest, err := NewProcessCookieTestWithDefaults()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
session, err := pcTest.LoadCookiedSession()
|
|
assert.Error(t, err, "cookie \"_oauth2_proxy\" not present")
|
|
if session != nil {
|
|
t.Errorf("expected nil session. got %#v", session)
|
|
}
|
|
}
|
|
|
|
func TestProcessCookieRefreshNotSet(t *testing.T) {
|
|
pcTest, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) {
|
|
opts.Cookie.Expire = time.Duration(23) * time.Hour
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
reference := time.Now().Add(time.Duration(-2) * time.Hour)
|
|
|
|
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
|
|
err = pcTest.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
session, err := pcTest.LoadCookiedSession()
|
|
assert.Equal(t, nil, err)
|
|
if session.Age() < time.Duration(-2)*time.Hour {
|
|
t.Errorf("cookie too young %v", session.Age())
|
|
}
|
|
assert.Equal(t, startSession.Email, session.Email)
|
|
}
|
|
|
|
func TestProcessCookieFailIfCookieExpired(t *testing.T) {
|
|
pcTest, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) {
|
|
opts.Cookie.Expire = time.Duration(24) * time.Hour
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
|
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
|
|
err = pcTest.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
session, err := pcTest.LoadCookiedSession()
|
|
assert.NotEqual(t, nil, err)
|
|
if session != nil {
|
|
t.Errorf("expected nil session %#v", session)
|
|
}
|
|
}
|
|
|
|
func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
|
|
pcTest, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) {
|
|
opts.Cookie.Expire = time.Duration(24) * time.Hour
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
|
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
|
|
err = pcTest.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
pcTest.proxy.CookieRefresh = time.Hour
|
|
session, err := pcTest.LoadCookiedSession()
|
|
assert.NotEqual(t, nil, err)
|
|
if session != nil {
|
|
t.Errorf("expected nil session %#v", session)
|
|
}
|
|
}
|
|
|
|
func NewUserInfoEndpointTest() (*ProcessCookieTest, error) {
|
|
pcTest, err := NewProcessCookieTestWithDefaults()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
pcTest.req, _ = http.NewRequest("GET",
|
|
pcTest.opts.ProxyPrefix+"/userinfo", nil)
|
|
return pcTest, nil
|
|
}
|
|
|
|
func TestUserInfoEndpointAccepted(t *testing.T) {
|
|
test, err := NewUserInfoEndpointTest()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
startSession := &sessions.SessionState{
|
|
Email: "john.doe@example.com", AccessToken: "my_access_token"}
|
|
err = test.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
assert.Equal(t, http.StatusOK, test.rw.Code)
|
|
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
|
assert.Equal(t, "{\"email\":\"john.doe@example.com\"}\n", string(bodyBytes))
|
|
}
|
|
|
|
func TestUserInfoEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
|
|
test, err := NewUserInfoEndpointTest()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
|
}
|
|
|
|
func NewAuthOnlyEndpointTest(modifiers ...OptionsModifier) (*ProcessCookieTest, error) {
|
|
pcTest, err := NewProcessCookieTestWithOptionsModifiers(modifiers...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
pcTest.req, _ = http.NewRequest("GET",
|
|
pcTest.opts.ProxyPrefix+"/auth", nil)
|
|
return pcTest, nil
|
|
}
|
|
|
|
func TestAuthOnlyEndpointAccepted(t *testing.T) {
|
|
test, err := NewAuthOnlyEndpointTest()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
created := time.Now()
|
|
startSession := &sessions.SessionState{
|
|
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &created}
|
|
err = test.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
assert.Equal(t, http.StatusAccepted, test.rw.Code)
|
|
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
|
assert.Equal(t, "", string(bodyBytes))
|
|
}
|
|
|
|
func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
|
|
test, err := NewAuthOnlyEndpointTest()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
|
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
|
assert.Equal(t, "unauthorized request\n", string(bodyBytes))
|
|
}
|
|
|
|
func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
|
|
test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) {
|
|
opts.Cookie.Expire = time.Duration(24) * time.Hour
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
|
startSession := &sessions.SessionState{
|
|
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
|
|
err = test.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
|
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
|
assert.Equal(t, "unauthorized request\n", string(bodyBytes))
|
|
}
|
|
|
|
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
|
|
test, err := NewAuthOnlyEndpointTest()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
created := time.Now()
|
|
startSession := &sessions.SessionState{
|
|
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &created}
|
|
err = test.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
test.validateUser = false
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
|
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
|
assert.Equal(t, "unauthorized request\n", string(bodyBytes))
|
|
}
|
|
|
|
func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
|
|
var pcTest ProcessCookieTest
|
|
|
|
pcTest.opts = baseTestOptions()
|
|
pcTest.opts.SetXAuthRequest = true
|
|
err := validation.Validate(pcTest.opts)
|
|
assert.NoError(t, err)
|
|
|
|
pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
|
|
return pcTest.validateUser
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
pcTest.proxy.provider = &TestProvider{
|
|
ValidToken: true,
|
|
}
|
|
|
|
pcTest.validateUser = true
|
|
|
|
pcTest.rw = httptest.NewRecorder()
|
|
pcTest.req, _ = http.NewRequest("GET",
|
|
pcTest.opts.ProxyPrefix+"/auth", nil)
|
|
|
|
created := time.Now()
|
|
startSession := &sessions.SessionState{
|
|
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
|
|
err = pcTest.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
|
|
assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
|
|
assert.Equal(t, "oauth_user", pcTest.rw.Header().Get("X-Auth-Request-User"))
|
|
assert.Equal(t, "oauth_user@example.com", pcTest.rw.Header().Get("X-Auth-Request-Email"))
|
|
}
|
|
|
|
func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) {
|
|
var pcTest ProcessCookieTest
|
|
|
|
pcTest.opts = baseTestOptions()
|
|
pcTest.opts.SetXAuthRequest = true
|
|
pcTest.opts.SetBasicAuth = true
|
|
err := validation.Validate(pcTest.opts)
|
|
assert.NoError(t, err)
|
|
|
|
pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
|
|
return pcTest.validateUser
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
pcTest.proxy.provider = &TestProvider{
|
|
ValidToken: true,
|
|
}
|
|
|
|
pcTest.validateUser = true
|
|
|
|
pcTest.rw = httptest.NewRecorder()
|
|
pcTest.req, _ = http.NewRequest("GET",
|
|
pcTest.opts.ProxyPrefix+"/auth", nil)
|
|
|
|
created := time.Now()
|
|
startSession := &sessions.SessionState{
|
|
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
|
|
err = pcTest.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
|
|
assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
|
|
assert.Equal(t, "oauth_user", pcTest.rw.Header().Values("X-Auth-Request-User")[0])
|
|
assert.Equal(t, "oauth_user@example.com", pcTest.rw.Header().Values("X-Auth-Request-Email")[0])
|
|
expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte("oauth_user:"+pcTest.opts.BasicAuthPassword))
|
|
assert.Equal(t, expectedHeader, pcTest.rw.Header().Values("Authorization")[0])
|
|
}
|
|
|
|
func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) {
|
|
var pcTest ProcessCookieTest
|
|
|
|
pcTest.opts = baseTestOptions()
|
|
pcTest.opts.SetXAuthRequest = true
|
|
pcTest.opts.SetBasicAuth = false
|
|
err := validation.Validate(pcTest.opts)
|
|
assert.NoError(t, err)
|
|
|
|
pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
|
|
return pcTest.validateUser
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
pcTest.proxy.provider = &TestProvider{
|
|
ValidToken: true,
|
|
}
|
|
|
|
pcTest.validateUser = true
|
|
|
|
pcTest.rw = httptest.NewRecorder()
|
|
pcTest.req, _ = http.NewRequest("GET",
|
|
pcTest.opts.ProxyPrefix+"/auth", nil)
|
|
|
|
created := time.Now()
|
|
startSession := &sessions.SessionState{
|
|
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
|
|
err = pcTest.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
|
|
assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
|
|
assert.Equal(t, "oauth_user", pcTest.rw.Header().Values("X-Auth-Request-User")[0])
|
|
assert.Equal(t, "oauth_user@example.com", pcTest.rw.Header().Values("X-Auth-Request-Email")[0])
|
|
assert.Equal(t, 0, len(pcTest.rw.Header().Values("Authorization")), "should not have Authorization header entries")
|
|
}
|
|
|
|
func TestAuthSkippedForPreflightRequests(t *testing.T) {
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(200)
|
|
_, err := w.Write([]byte("response"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}))
|
|
t.Cleanup(upstream.Close)
|
|
|
|
opts := baseTestOptions()
|
|
opts.Upstreams = append(opts.Upstreams, upstream.URL)
|
|
opts.SkipAuthPreflight = true
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
|
|
upstreamURL, _ := url.Parse(upstream.URL)
|
|
opts.SetProvider(NewTestProvider(upstreamURL, ""))
|
|
|
|
proxy, err := NewOAuthProxy(opts, func(string) bool { return false })
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
rw := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("OPTIONS", "/preflight-request", nil)
|
|
proxy.ServeHTTP(rw, req)
|
|
|
|
assert.Equal(t, 200, rw.Code)
|
|
assert.Equal(t, "response", rw.Body.String())
|
|
}
|
|
|
|
type SignatureAuthenticator struct {
|
|
auth hmacauth.HmacAuth
|
|
}
|
|
|
|
func (v *SignatureAuthenticator) Authenticate(w http.ResponseWriter, r *http.Request) {
|
|
result, headerSig, computedSig := v.auth.AuthenticateRequest(r)
|
|
|
|
var msg string
|
|
switch result {
|
|
case hmacauth.ResultNoSignature:
|
|
msg = "no signature received"
|
|
case hmacauth.ResultMatch:
|
|
msg = "signatures match"
|
|
case hmacauth.ResultMismatch:
|
|
msg = fmt.Sprintf(
|
|
"signatures do not match:\n received: %s\n computed: %s",
|
|
headerSig,
|
|
computedSig)
|
|
default:
|
|
panic("unknown result value: " + result.String())
|
|
}
|
|
|
|
_, err := w.Write([]byte(msg))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
type SignatureTest struct {
|
|
opts *options.Options
|
|
upstream *httptest.Server
|
|
upstreamHost string
|
|
provider *httptest.Server
|
|
header http.Header
|
|
rw *httptest.ResponseRecorder
|
|
authenticator *SignatureAuthenticator
|
|
}
|
|
|
|
func NewSignatureTest() (*SignatureTest, error) {
|
|
opts := baseTestOptions()
|
|
opts.EmailDomains = []string{"acm.org"}
|
|
|
|
authenticator := &SignatureAuthenticator{}
|
|
upstream := httptest.NewServer(
|
|
http.HandlerFunc(authenticator.Authenticate))
|
|
upstreamURL, err := url.Parse(upstream.URL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
opts.Upstreams = append(opts.Upstreams, upstream.URL)
|
|
|
|
providerHandler := func(w http.ResponseWriter, r *http.Request) {
|
|
_, err := w.Write([]byte(`{"access_token": "my_auth_token"}`))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
provider := httptest.NewServer(http.HandlerFunc(providerHandler))
|
|
providerURL, err := url.Parse(provider.URL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
opts.SetProvider(NewTestProvider(providerURL, "mbland@acm.org"))
|
|
|
|
return &SignatureTest{
|
|
opts,
|
|
upstream,
|
|
upstreamURL.Host,
|
|
provider,
|
|
make(http.Header),
|
|
httptest.NewRecorder(),
|
|
authenticator,
|
|
}, nil
|
|
}
|
|
|
|
func (st *SignatureTest) Close() {
|
|
st.provider.Close()
|
|
st.upstream.Close()
|
|
}
|
|
|
|
// fakeNetConn simulates an http.Request.Body buffer that will be consumed
// when it is read by the hmacauth.HmacAuth if not handled properly. See:
// https://github.com/18F/hmacauth/pull/4
type fakeNetConn struct {
	// reqBody holds the bytes that have not been handed out by Read yet.
	reqBody string
}

// Read copies as much of the remaining body as fits into p and consumes it.
//
// It returns the number of bytes actually copied, which never exceeds
// len(p). io.EOF is returned together with the final bytes (and on every
// later call), so a body that fits into p is delivered as (len(body), io.EOF)
// in a single call.
func (fnc *fakeNetConn) Read(p []byte) (n int, err error) {
	// copy reports how many bytes really fit; returning len(reqBody)
	// unconditionally would violate the io.Reader contract for short
	// buffers and silently drop the rest of the body.
	n = copy(p, fnc.reqBody)
	fnc.reqBody = fnc.reqBody[n:]
	if len(fnc.reqBody) == 0 {
		return n, io.EOF
	}
	return n, nil
}
|
|
|
|
func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) error {
|
|
err := validation.Validate(st.opts)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
proxy, err := NewOAuthProxy(st.opts, func(email string) bool { return true })
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var bodyBuf io.ReadCloser
|
|
if body != "" {
|
|
bodyBuf = ioutil.NopCloser(&fakeNetConn{reqBody: body})
|
|
}
|
|
req := httptest.NewRequest(method, "/foo/bar", bodyBuf)
|
|
req.Header = st.header
|
|
|
|
state := &sessions.SessionState{
|
|
Email: "mbland@acm.org", AccessToken: "my_access_token"}
|
|
err = proxy.SaveSession(st.rw, req, state)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, c := range st.rw.Result().Cookies() {
|
|
req.AddCookie(c)
|
|
}
|
|
// This is used by the upstream to validate the signature.
|
|
st.authenticator.auth = hmacauth.NewHmacAuth(
|
|
crypto.SHA1, []byte(key), SignatureHeader, SignatureHeaders)
|
|
proxy.ServeHTTP(st.rw, req)
|
|
|
|
return nil
|
|
}
|
|
|
|
func TestRequestSignature(t *testing.T) {
|
|
testCases := map[string]struct {
|
|
method string
|
|
body string
|
|
key string
|
|
resp string
|
|
}{
|
|
"No request signature": {
|
|
method: "GET",
|
|
body: "",
|
|
key: "",
|
|
resp: "no signature received",
|
|
},
|
|
"Get request": {
|
|
method: "GET",
|
|
body: "",
|
|
key: "7d9e1aa87a5954e6f9fc59266b3af9d7c35fda2d",
|
|
resp: "signatures match",
|
|
},
|
|
"Post request": {
|
|
method: "POST",
|
|
body: `{ "hello": "world!" }`,
|
|
key: "d90df39e2d19282840252612dd7c81421a372f61",
|
|
resp: "signatures match",
|
|
},
|
|
}
|
|
for name, tc := range testCases {
|
|
t.Run(name, func(t *testing.T) {
|
|
st, err := NewSignatureTest()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Cleanup(st.Close)
|
|
if tc.key != "" {
|
|
st.opts.SignatureKey = fmt.Sprintf("sha1:%s", tc.key)
|
|
}
|
|
err = st.MakeRequestWithExpectedKey(tc.method, tc.body, tc.key)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, 200, st.rw.Code)
|
|
assert.Equal(t, tc.resp, st.rw.Body.String())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetRedirect(t *testing.T) {
|
|
opts := baseTestOptions()
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
require.NotEmpty(t, opts.ProxyPrefix)
|
|
proxy, err := NewOAuthProxy(opts, func(s string) bool { return false })
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
url string
|
|
expectedRedirect string
|
|
}{
|
|
{
|
|
name: "request outside of ProxyPrefix redirects to original URL",
|
|
url: "/foo/bar",
|
|
expectedRedirect: "/foo/bar",
|
|
},
|
|
{
|
|
name: "request under ProxyPrefix redirects to root",
|
|
url: proxy.ProxyPrefix + "/foo/bar",
|
|
expectedRedirect: "/",
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req, _ := http.NewRequest("GET", tt.url, nil)
|
|
redirect, err := proxy.GetRedirect(req)
|
|
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tt.expectedRedirect, redirect)
|
|
})
|
|
}
|
|
}
|
|
|
|
type ajaxRequestTest struct {
|
|
opts *options.Options
|
|
proxy *OAuthProxy
|
|
}
|
|
|
|
func newAjaxRequestTest() (*ajaxRequestTest, error) {
|
|
test := &ajaxRequestTest{}
|
|
test.opts = baseTestOptions()
|
|
err := validation.Validate(test.opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
test.proxy, err = NewOAuthProxy(test.opts, func(email string) bool {
|
|
return true
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return test, nil
|
|
}
|
|
|
|
func (test *ajaxRequestTest) getEndpoint(endpoint string, header http.Header) (int, http.Header, error) {
|
|
rw := httptest.NewRecorder()
|
|
req, err := http.NewRequest(http.MethodGet, endpoint, strings.NewReader(""))
|
|
if err != nil {
|
|
return 0, nil, err
|
|
}
|
|
req.Header = header
|
|
test.proxy.ServeHTTP(rw, req)
|
|
return rw.Code, rw.Header(), nil
|
|
}
|
|
|
|
func testAjaxUnauthorizedRequest(t *testing.T, header http.Header) {
|
|
test, err := newAjaxRequestTest()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
endpoint := "/test"
|
|
|
|
code, rh, err := test.getEndpoint(endpoint, header)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, http.StatusUnauthorized, code)
|
|
mime := rh.Get("Content-Type")
|
|
assert.Equal(t, applicationJSON, mime)
|
|
}
|
|
func TestAjaxUnauthorizedRequest1(t *testing.T) {
|
|
header := make(http.Header)
|
|
header.Add("accept", applicationJSON)
|
|
|
|
testAjaxUnauthorizedRequest(t, header)
|
|
}
|
|
|
|
func TestAjaxUnauthorizedRequest2(t *testing.T) {
|
|
header := make(http.Header)
|
|
header.Add("Accept", applicationJSON)
|
|
|
|
testAjaxUnauthorizedRequest(t, header)
|
|
}
|
|
|
|
func TestAjaxForbiddendRequest(t *testing.T) {
|
|
test, err := newAjaxRequestTest()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
endpoint := "/test"
|
|
header := make(http.Header)
|
|
code, rh, err := test.getEndpoint(endpoint, header)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, http.StatusForbidden, code)
|
|
mime := rh.Get("Content-Type")
|
|
assert.NotEqual(t, applicationJSON, mime)
|
|
}
|
|
|
|
func TestClearSplitCookie(t *testing.T) {
|
|
opts := baseTestOptions()
|
|
opts.Cookie.Secret = base64CookieSecret
|
|
opts.Cookie.Name = "oauth2"
|
|
opts.Cookie.Domains = []string{"abc"}
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
|
|
store, err := sessionscookie.NewCookieSessionStore(&opts.Session, &opts.Cookie)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
p := OAuthProxy{CookieName: opts.Cookie.Name, CookieDomains: opts.Cookie.Domains, sessionStore: store}
|
|
var rw = httptest.NewRecorder()
|
|
req := httptest.NewRequest("get", "/", nil)
|
|
|
|
req.AddCookie(&http.Cookie{
|
|
Name: "test1",
|
|
Value: "test1",
|
|
})
|
|
req.AddCookie(&http.Cookie{
|
|
Name: "oauth2_0",
|
|
Value: "oauth2_0",
|
|
})
|
|
req.AddCookie(&http.Cookie{
|
|
Name: "oauth2_1",
|
|
Value: "oauth2_1",
|
|
})
|
|
|
|
err = p.ClearSessionCookie(rw, req)
|
|
assert.NoError(t, err)
|
|
header := rw.Header()
|
|
|
|
assert.Equal(t, 2, len(header["Set-Cookie"]), "should have 3 set-cookie header entries")
|
|
}
|
|
|
|
func TestClearSingleCookie(t *testing.T) {
|
|
opts := baseTestOptions()
|
|
opts.Cookie.Name = "oauth2"
|
|
opts.Cookie.Domains = []string{"abc"}
|
|
store, err := sessionscookie.NewCookieSessionStore(&opts.Session, &opts.Cookie)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
p := OAuthProxy{CookieName: opts.Cookie.Name, CookieDomains: opts.Cookie.Domains, sessionStore: store}
|
|
var rw = httptest.NewRecorder()
|
|
req := httptest.NewRequest("get", "/", nil)
|
|
|
|
req.AddCookie(&http.Cookie{
|
|
Name: "test1",
|
|
Value: "test1",
|
|
})
|
|
req.AddCookie(&http.Cookie{
|
|
Name: "oauth2",
|
|
Value: "oauth2",
|
|
})
|
|
|
|
err = p.ClearSessionCookie(rw, req)
|
|
assert.NoError(t, err)
|
|
header := rw.Header()
|
|
|
|
assert.Equal(t, 1, len(header["Set-Cookie"]), "should have 1 set-cookie header entries")
|
|
}
|
|
|
|
// NoOpKeySet is a key set for tests that performs no signature verification.
type NoOpKeySet struct {
}

// VerifySignature returns the base64url-decoded (unpadded) second
// dot-separated segment of jwt without checking any signature. ctx is unused.
//
// An error is returned when jwt has fewer than two segments or when the
// payload segment is not valid unpadded base64url.
func (NoOpKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) {
	segments := strings.Split(jwt, ".")
	// Guard the index below: a token without any "." used to panic here.
	if len(segments) < 2 {
		return nil, fmt.Errorf("malformed jwt: expected at least 2 dot-separated segments, got %d", len(segments))
	}
	return base64.RawURLEncoding.DecodeString(segments[1])
}
|
|
|
|
func TestGetJwtSession(t *testing.T) {
|
|
/* token payload:
|
|
{
|
|
"sub": "1234567890",
|
|
"aud": "https://test.myapp.com",
|
|
"name": "John Doe",
|
|
"email": "john@example.com",
|
|
"iss": "https://issuer.example.com",
|
|
"iat": 1553691215,
|
|
"exp": 1912151821
|
|
}
|
|
*/
|
|
goodJwt := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." +
|
|
"eyJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjoiaHR0cHM6Ly90ZXN0Lm15YXBwLmNvbSIsIm5hbWUiOiJKb2huIERvZSIsImVtY" +
|
|
"WlsIjoiam9obkBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwiaWF0IjoxNTUzNjkxMj" +
|
|
"E1LCJleHAiOjE5MTIxNTE4MjF9." +
|
|
"rLVyzOnEldUq_pNkfa-WiV8TVJYWyZCaM2Am_uo8FGg11zD7l-qmz3x1seTvqpH6Y0Ty00fmv6dJnGnC8WMnPXQiodRTfhBSe" +
|
|
"OKZMu0HkMD2sg52zlKkbfLTO6ic5VnbVgwjjrB8am_Ta6w7kyFUaB5C1BsIrrLMldkWEhynbb8"
|
|
|
|
keyset := NoOpKeySet{}
|
|
verifier := oidc.NewVerifier("https://issuer.example.com", keyset,
|
|
&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true})
|
|
|
|
test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) {
|
|
opts.PassAuthorization = true
|
|
opts.SetAuthorization = true
|
|
opts.SetXAuthRequest = true
|
|
opts.SkipJwtBearerTokens = true
|
|
opts.SetJWTBearerVerifiers(append(opts.GetJWTBearerVerifiers(), verifier))
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
tp, _ := test.proxy.provider.(*TestProvider)
|
|
tp.GroupValidator = func(s string) bool {
|
|
return true
|
|
}
|
|
|
|
authHeader := fmt.Sprintf("Bearer %s", goodJwt)
|
|
test.req.Header = map[string][]string{
|
|
"Authorization": {authHeader},
|
|
}
|
|
|
|
// Bearer
|
|
expires := time.Unix(1912151821, 0)
|
|
session, err := test.proxy.GetJwtSession(test.req)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, session.User, "1234567890")
|
|
assert.Equal(t, session.Email, "john@example.com")
|
|
assert.Equal(t, session.ExpiresOn, &expires)
|
|
assert.Equal(t, session.IDToken, goodJwt)
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
if test.rw.Code >= 400 {
|
|
t.Fatalf("expected 3xx got %d", test.rw.Code)
|
|
}
|
|
|
|
// Check PassAuthorization, should overwrite Basic header
|
|
assert.Equal(t, test.req.Header.Get("Authorization"), authHeader)
|
|
assert.Equal(t, test.req.Header.Get("X-Forwarded-User"), "1234567890")
|
|
assert.Equal(t, test.req.Header.Get("X-Forwarded-Email"), "john@example.com")
|
|
|
|
// SetAuthorization and SetXAuthRequest
|
|
assert.Equal(t, test.rw.Header().Get("Authorization"), authHeader)
|
|
assert.Equal(t, test.rw.Header().Get("X-Auth-Request-User"), "1234567890")
|
|
assert.Equal(t, test.rw.Header().Get("X-Auth-Request-Email"), "john@example.com")
|
|
}
|
|
|
|
func TestFindJwtBearerToken(t *testing.T) {
|
|
p := OAuthProxy{CookieName: "oauth2", CookieDomains: []string{"abc"}}
|
|
getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: "example.com"}}
|
|
|
|
validToken := "eyJfoobar.eyJfoobar.12345asdf"
|
|
var token string
|
|
|
|
// Bearer
|
|
getReq.Header = map[string][]string{
|
|
"Authorization": {fmt.Sprintf("Bearer %s", validToken)},
|
|
}
|
|
|
|
token, err := p.findBearerToken(getReq)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, validToken, token)
|
|
|
|
// Basic - no password
|
|
getReq.SetBasicAuth(token, "")
|
|
token, err = p.findBearerToken(getReq)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, validToken, token)
|
|
|
|
// Basic - sentinel password
|
|
getReq.SetBasicAuth(token, "x-oauth-basic")
|
|
token, err = p.findBearerToken(getReq)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, validToken, token)
|
|
|
|
// Basic - any username, password matching jwt pattern
|
|
getReq.SetBasicAuth("any-username-you-could-wish-for", token)
|
|
token, err = p.findBearerToken(getReq)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, validToken, token)
|
|
|
|
failures := []string{
|
|
// Too many parts
|
|
"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA.dGVzdA.dGVzdA",
|
|
// Not enough parts
|
|
"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA",
|
|
// Invalid encrypted key
|
|
"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.//////.dGVzdA.dGVzdA.dGVzdA",
|
|
// Invalid IV
|
|
"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.//////.dGVzdA.dGVzdA",
|
|
// Invalid ciphertext
|
|
"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.//////.dGVzdA",
|
|
// Invalid tag
|
|
"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA.//////",
|
|
// Invalid header
|
|
"W10.dGVzdA.dGVzdA.dGVzdA.dGVzdA",
|
|
// Invalid header
|
|
"######.dGVzdA.dGVzdA.dGVzdA.dGVzdA",
|
|
// Missing alc/enc params
|
|
"e30.dGVzdA.dGVzdA.dGVzdA.dGVzdA",
|
|
}
|
|
|
|
for _, failure := range failures {
|
|
getReq.Header = map[string][]string{
|
|
"Authorization": {fmt.Sprintf("Bearer %s", failure)},
|
|
}
|
|
_, err := p.findBearerToken(getReq)
|
|
assert.Error(t, err)
|
|
}
|
|
}
|
|
|
|
func Test_prepareNoCache(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
prepareNoCache(w)
|
|
})
|
|
mux := http.NewServeMux()
|
|
mux.Handle("/", handler)
|
|
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
for k, v := range noCacheHeaders {
|
|
assert.Equal(t, rec.Header().Get(k), v)
|
|
}
|
|
}
|
|
|
|
func Test_noCacheHeaders(t *testing.T) {
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, err := w.Write([]byte("upstream"))
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
}))
|
|
t.Cleanup(upstream.Close)
|
|
|
|
opts := baseTestOptions()
|
|
opts.Upstreams = []string{upstream.URL}
|
|
opts.SkipAuthRegex = []string{".*"}
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
t.Run("not exist in response from upstream", func(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/upstream", nil)
|
|
proxy.ServeHTTP(rec, req)
|
|
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
assert.Equal(t, "upstream", rec.Body.String())
|
|
|
|
// checking noCacheHeaders does not exists in response headers from upstream
|
|
for k := range noCacheHeaders {
|
|
assert.Equal(t, "", rec.Header().Get(k))
|
|
}
|
|
})
|
|
|
|
t.Run("has no-cache", func(t *testing.T) {
|
|
tests := []struct {
|
|
path string
|
|
hasNoCache bool
|
|
}{
|
|
{
|
|
path: "/oauth2/sign_in",
|
|
hasNoCache: true,
|
|
},
|
|
{
|
|
path: "/oauth2/sign_out",
|
|
hasNoCache: true,
|
|
},
|
|
{
|
|
path: "/oauth2/start",
|
|
hasNoCache: true,
|
|
},
|
|
{
|
|
path: "/oauth2/callback",
|
|
hasNoCache: true,
|
|
},
|
|
{
|
|
path: "/oauth2/auth",
|
|
hasNoCache: false,
|
|
},
|
|
{
|
|
path: "/oauth2/userinfo",
|
|
hasNoCache: true,
|
|
},
|
|
{
|
|
path: "/upstream",
|
|
hasNoCache: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.path, func(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, tt.path, nil)
|
|
proxy.ServeHTTP(rec, req)
|
|
cacheControl := rec.Result().Header.Get("Cache-Control")
|
|
if tt.hasNoCache != (strings.Contains(cacheControl, "no-cache")) {
|
|
t.Errorf(`unexpected "Cache-Control" header: %s`, cacheControl)
|
|
}
|
|
})
|
|
}
|
|
|
|
})
|
|
}
|
|
|
|
func baseTestOptions() *options.Options {
|
|
opts := options.NewOptions()
|
|
opts.Cookie.Secret = rawCookieSecret
|
|
opts.ClientID = clientID
|
|
opts.ClientSecret = clientSecret
|
|
opts.EmailDomains = []string{"*"}
|
|
return opts
|
|
}
|
|
|
|
func TestTrustedIPs(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
trustedIPs []string
|
|
reverseProxy bool
|
|
realClientIPHeader string
|
|
req *http.Request
|
|
expectTrusted bool
|
|
}{
|
|
// Check unconfigured behavior.
|
|
{
|
|
name: "Default",
|
|
trustedIPs: nil,
|
|
reverseProxy: false,
|
|
realClientIPHeader: "X-Real-IP", // Default value
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
return req
|
|
}(),
|
|
expectTrusted: false,
|
|
},
|
|
// Check using req.RemoteAddr (Options.ReverseProxy == false).
|
|
{
|
|
name: "WithRemoteAddr",
|
|
trustedIPs: []string{"127.0.0.1"},
|
|
reverseProxy: false,
|
|
realClientIPHeader: "X-Real-IP", // Default value
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = "127.0.0.1:43670"
|
|
return req
|
|
}(),
|
|
expectTrusted: true,
|
|
},
|
|
// Check ignores req.RemoteAddr match when behind a reverse proxy / missing header.
|
|
{
|
|
name: "IgnoresRemoteAddrInReverseProxyMode",
|
|
trustedIPs: []string{"127.0.0.1"},
|
|
reverseProxy: true,
|
|
realClientIPHeader: "X-Real-IP", // Default value
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = "127.0.0.1:44324"
|
|
return req
|
|
}(),
|
|
expectTrusted: false,
|
|
},
|
|
// Check successful trusting of localhost in IPv4.
|
|
{
|
|
name: "TrustsLocalhostInReverseProxyMode",
|
|
trustedIPs: []string{"127.0.0.0/8", "::1"},
|
|
reverseProxy: true,
|
|
realClientIPHeader: "X-Forwarded-For",
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.Header.Add("X-Forwarded-For", "127.0.0.1")
|
|
return req
|
|
}(),
|
|
expectTrusted: true,
|
|
},
|
|
// Check successful trusting of localhost in IPv6.
|
|
{
|
|
name: "TrustsIP6LocalostInReverseProxyMode",
|
|
trustedIPs: []string{"127.0.0.0/8", "::1"},
|
|
reverseProxy: true,
|
|
realClientIPHeader: "X-Forwarded-For",
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.Header.Add("X-Forwarded-For", "::1")
|
|
return req
|
|
}(),
|
|
expectTrusted: true,
|
|
},
|
|
// Check does not trust random IPv4 address.
|
|
{
|
|
name: "DoesNotTrustRandomIP4Address",
|
|
trustedIPs: []string{"127.0.0.0/8", "::1"},
|
|
reverseProxy: true,
|
|
realClientIPHeader: "X-Forwarded-For",
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.Header.Add("X-Forwarded-For", "12.34.56.78")
|
|
return req
|
|
}(),
|
|
expectTrusted: false,
|
|
},
|
|
// Check does not trust random IPv6 address.
|
|
{
|
|
name: "DoesNotTrustRandomIP6Address",
|
|
trustedIPs: []string{"127.0.0.0/8", "::1"},
|
|
reverseProxy: true,
|
|
realClientIPHeader: "X-Forwarded-For",
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.Header.Add("X-Forwarded-For", "::2")
|
|
return req
|
|
}(),
|
|
expectTrusted: false,
|
|
},
|
|
// Check respects correct header.
|
|
{
|
|
name: "RespectsCorrectHeaderInReverseProxyMode",
|
|
trustedIPs: []string{"127.0.0.0/8", "::1"},
|
|
reverseProxy: true,
|
|
realClientIPHeader: "X-Forwarded-For",
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.Header.Add("X-Real-IP", "::1")
|
|
return req
|
|
}(),
|
|
expectTrusted: false,
|
|
},
|
|
// Check doesn't trust if garbage is provided.
|
|
{
|
|
name: "DoesNotTrustGarbageInReverseProxyMode",
|
|
trustedIPs: []string{"127.0.0.0/8", "::1"},
|
|
reverseProxy: true,
|
|
realClientIPHeader: "X-Forwarded-For",
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.Header.Add("X-Forwarded-For", "adsfljk29242as!!")
|
|
return req
|
|
}(),
|
|
expectTrusted: false,
|
|
},
|
|
// Check doesn't trust if garbage is provided (no reverse-proxy).
|
|
{
|
|
name: "DoesNotTrustGarbage",
|
|
trustedIPs: []string{"127.0.0.0/8", "::1"},
|
|
reverseProxy: false,
|
|
realClientIPHeader: "X-Real-IP",
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = "adsfljk29242as!!"
|
|
return req
|
|
}(),
|
|
expectTrusted: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
opts := baseTestOptions()
|
|
opts.Upstreams = []string{"static://200"}
|
|
opts.TrustedIPs = tt.trustedIPs
|
|
opts.ReverseProxy = tt.reverseProxy
|
|
opts.RealClientIPHeader = tt.realClientIPHeader
|
|
validation.Validate(opts)
|
|
|
|
proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
|
|
assert.NoError(t, err)
|
|
rw := httptest.NewRecorder()
|
|
|
|
proxy.ServeHTTP(rw, tt.req)
|
|
if tt.expectTrusted {
|
|
assert.Equal(t, 200, rw.Code)
|
|
} else {
|
|
assert.Equal(t, 403, rw.Code)
|
|
}
|
|
})
|
|
}
|
|
}
|