1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-01-10 04:18:14 +02:00
oauth2-proxy/oauthproxy_test.go
Kevin Schu 25371ea4af
improved audience handling to support client credentials access tokens without aud claims (#1204)
* implementation draft

* add cfg options skip-au-when-missing && client-id-verification-claim; enhance the provider data verification logic for sake of the added options

* refactor configs, added logging and add additional claim verification

* simplify logic by just having one configuration similar to oidc-email-claim

* added internal oidc token verifier, so that aud check behavior can be managed with oauth2-proxy and is compatible with extra-jwt-issuers

* refactored verification to reduce complexity

* refactored verification to reduce complexity

* added docs

* adjust tests to support new OIDCAudienceClaim and OIDCExtraAudiences options

* extend unit tests and ensure that audience is set with the value of aud claim configuration

* revert filemodes and update docs

* update docs

* remove unnecessary logging, refactor audience existence check and added additional unit tests

* fix linting issues after rebase on origin/main

* cleanup: use new imports for migrated libraries after rebase on origin/main

* adapt mock in keycloak_oidc_test.go

* allow specifying multiple audience claims, fixed bug where the JWT issuer's client ID was not being considered, and fixed bug where aud claims with multiple audiences broke the whole validation

* fixed formatting issue

* do not pass the whole options struct to minimize complexity and dependency to the configuration structure

* added changelog entry

* update docs

Co-authored-by: Sofia Weiler <sofia.weiler@aoe.com>
Co-authored-by: Christian Zenker <christian.zenker@aoe.com>
2022-02-15 16:12:22 +00:00

2694 lines
69 KiB
Go

package main
import (
"context"
"crypto"
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"strings"
"testing"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/mbland/hmacauth"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/oidc"
sessionscookie "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/cookie"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/upstream"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/validation"
"github.com/oauth2-proxy/oauth2-proxy/v7/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
	// rawCookieSecret is a 32-byte secret; base64CookieSecret is its base64
	// encoding.
	rawCookieSecret    = "secretthirtytwobytes+abcdefghijk"
	base64CookieSecret = "c2VjcmV0dGhpcnR5dHdvYnl0ZXMrYWJjZGVmZ2hpams"

	// Placeholder OAuth client credentials (presumably consumed by
	// baseTestOptions elsewhere in this file).
	clientID     = "3984n253984d7348dm8234yf982t"
	clientSecret = "gv3498mfc9t23y23974dm2394dm9"
)

// init makes the shared logger prefix each line with the short source file
// name (logger.Lshortfile), which helps locate log output from tests.
func init() {
	logger.SetFlags(logger.Lshortfile)
}
// TestRobotsTxt checks that the proxy serves a robots.txt which disallows all
// crawlers.
func TestRobotsTxt(t *testing.T) {
	opts := baseTestOptions()
	assert.NoError(t, validation.Validate(opts))

	proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
	if err != nil {
		t.Fatal(err)
	}

	recorder := httptest.NewRecorder()
	request, _ := http.NewRequest("GET", "/robots.txt", nil)
	proxy.ServeHTTP(recorder, request)

	assert.Equal(t, 200, recorder.Code)
	assert.Equal(t, "User-agent: *\nDisallow: /\n", recorder.Body.String())
}
// TestProvider is a stub providers.Provider for these tests. It embeds
// *providers.ProviderData and defines GetEmailAddress and ValidateSession to
// return the canned values stored in its fields.
type TestProvider struct {
*providers.ProviderData
// EmailAddress is returned verbatim by GetEmailAddress.
EmailAddress string
// ValidToken is returned verbatim by ValidateSession.
ValidToken bool
// GroupValidator is set by NewTestProvider to accept every group.
// NOTE(review): no reader of this field is visible in this chunk.
GroupValidator func(string) bool
}
// Compile-time check that *TestProvider satisfies providers.Provider.
var _ providers.Provider = (*TestProvider)(nil)
// NewTestProvider returns a TestProvider named "Test Provider". Its login,
// redeem and profile endpoints all use plain HTTP on providerURL's host, it
// reports emailAddress from GetEmailAddress, and its GroupValidator accepts
// every group.
func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
	endpoint := func(path string) *url.URL {
		return &url.URL{Scheme: "http", Host: providerURL.Host, Path: path}
	}

	data := &providers.ProviderData{
		ProviderName: "Test Provider",
		LoginURL:     endpoint("/oauth/authorize"),
		RedeemURL:    endpoint("/oauth/token"),
		ProfileURL:   endpoint("/api/v1/profile"),
		Scope:        "profile.email",
	}

	return &TestProvider{
		ProviderData:   data,
		EmailAddress:   emailAddress,
		GroupValidator: func(string) bool { return true },
	}
}
// GetEmailAddress ignores its arguments and returns the configured
// EmailAddress. It never returns an error.
func (tp *TestProvider) GetEmailAddress(context.Context, *sessions.SessionState) (string, error) {
	email := tp.EmailAddress
	return email, nil
}

// ValidateSession ignores the session and reports the configured ValidToken.
func (tp *TestProvider) ValidateSession(context.Context, *sessions.SessionState) bool {
	valid := tp.ValidToken
	return valid
}
// Test_redeemCode checks that redeemCode, given a request for "/" with no
// query string, fails with providers.ErrMissingCode.
func Test_redeemCode(t *testing.T) {
	opts := baseTestOptions()
	assert.NoError(t, validation.Validate(opts))

	proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
	if err != nil {
		t.Fatal(err)
	}

	_, err = proxy.redeemCode(httptest.NewRequest(http.MethodGet, "/", nil))
	assert.Equal(t, providers.ErrMissingCode, err)
}
// Test_enrichSession checks that enrichSessionState fills in a missing Email
// from the provider, while leaving fields that are already set, or that the
// provider cannot supply, unchanged.
func Test_enrichSession(t *testing.T) {
	const (
		sessionUser   = "Mr Session"
		sessionEmail  = "session@example.com"
		providerEmail = "provider@example.com"
	)

	cases := []struct {
		name          string
		session       *sessions.SessionState
		expectedUser  string
		expectedEmail string
	}{
		{
			name:          "Session already has enrichable fields",
			session:       &sessions.SessionState{User: sessionUser, Email: sessionEmail},
			expectedUser:  sessionUser,
			expectedEmail: sessionEmail,
		},
		{
			name:          "Session is missing Email and GetEmailAddress is implemented",
			session:       &sessions.SessionState{User: sessionUser},
			expectedUser:  sessionUser,
			expectedEmail: providerEmail,
		},
		{
			name:          "Session is missing User and GetUserName is not implemented",
			session:       &sessions.SessionState{Email: sessionEmail},
			expectedUser:  "",
			expectedEmail: sessionEmail,
		},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			opts := baseTestOptions()
			assert.NoError(t, validation.Validate(opts))
			// Install the TestProvider only after validation.Validate(opts),
			// which would otherwise clobber it by calling providers.New
			// (defaulting to providers.GoogleProvider).
			opts.SetProvider(NewTestProvider(&url.URL{Host: "www.example.com"}, providerEmail))

			proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
			if err != nil {
				t.Fatal(err)
			}

			assert.NoError(t, proxy.enrichSessionState(context.Background(), tc.session))
			assert.Equal(t, tc.expectedUser, tc.session.User)
			assert.Equal(t, tc.expectedEmail, tc.session.Email)
		})
	}
}
// TestBasicAuthPassword verifies that an Authorization header, built from the
// "email" claim plus a static basic-auth password, reaches the upstream.
// The fake server acts as both provider and upstream. For any path other than
// /oauth/token it echoes the Authorization header it received as the
// response body.
func TestBasicAuthPassword(t *testing.T) {
	providerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		logger.Printf("%#v", r)
		var payload string
		switch r.URL.Path {
		case "/oauth/token":
			payload = `{"access_token": "my_auth_token"}`
		default:
			payload = r.Header.Get("Authorization")
			if payload == "" {
				payload = "No Authorization header found."
			}
		}
		w.WriteHeader(200)
		// This handler runs on the server's goroutine, not the test goroutine.
		// t.Fatal must not be called from here; t.Error is safe.
		if _, err := w.Write([]byte(payload)); err != nil {
			t.Error(err)
		}
	}))
	// Register the close as a cleanup so the server shuts down even when the
	// test stops early via t.Fatal.
	t.Cleanup(providerServer.Close)

	basicAuthPassword := "This is a secure password"
	opts := baseTestOptions()
	opts.UpstreamServers = options.UpstreamConfig{
		Upstreams: []options.Upstream{
			{
				ID:   providerServer.URL,
				Path: "/",
				URI:  providerServer.URL,
			},
		},
	}
	opts.Cookie.Secure = false
	opts.InjectRequestHeaders = []options.Header{
		{
			Name: "Authorization",
			Values: []options.HeaderValue{
				{
					ClaimSource: &options.ClaimSource{
						Claim: "email",
						BasicAuthPassword: &options.SecretSource{
							Value: []byte(basicAuthPassword),
						},
					},
				},
			},
		},
	}
	err := validation.Validate(opts)
	assert.NoError(t, err)

	providerURL, err := url.Parse(providerServer.URL)
	if err != nil {
		t.Fatal(err)
	}
	const emailAddress = "john.doe@example.com"
	opts.SetProvider(NewTestProvider(providerURL, emailAddress))
	proxy, err := NewOAuthProxy(opts, func(email string) bool {
		return email == emailAddress
	})
	if err != nil {
		t.Fatal(err)
	}

	// Save the required session.
	rw := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/", nil)
	err = proxy.sessionStore.Save(rw, req, &sessions.SessionState{
		Email: emailAddress,
	})
	assert.NoError(t, err)

	// Extract the cookie value to inject into the test request. Fail with a
	// clear message, rather than an index-out-of-range panic, if none was set.
	setCookies := rw.Header().Values("Set-Cookie")
	if len(setCookies) == 0 {
		t.Fatal("expected the session store to set a cookie")
	}
	req, _ = http.NewRequest("GET", "/", nil)
	req.Header.Set("Cookie", setCookies[0])
	rw = httptest.NewRecorder()
	proxy.ServeHTTP(rw, req)

	// The basic-auth username is the session email, so the same constant is
	// used to build the expected header.
	expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(emailAddress+":"+basicAuthPassword))
	assert.Equal(t, expectedHeader, rw.Body.String())
}
// TestPassGroupsHeadersWithGroups checks that a session's groups are injected
// into the request as a single comma-joined X-Forwarded-Groups value.
func TestPassGroupsHeadersWithGroups(t *testing.T) {
	const (
		emailAddress = "john.doe@example.com"
		userName     = "9fcab5c9b889a557"
	)

	opts := baseTestOptions()
	opts.InjectRequestHeaders = []options.Header{
		{
			Name: "X-Forwarded-Groups",
			Values: []options.HeaderValue{
				{ClaimSource: &options.ClaimSource{Claim: "groups"}},
			},
		},
	}
	assert.NoError(t, validation.Validate(opts))

	now := time.Now()
	session := &sessions.SessionState{
		User:        userName,
		Groups:      []string{"a", "b"},
		Email:       emailAddress,
		AccessToken: "oauth_token",
		CreatedAt:   &now,
	}
	proxy, err := NewOAuthProxy(opts, func(email string) bool { return email == emailAddress })
	assert.NoError(t, err)

	// Persist the session and capture the cookie the store emits.
	saveRecorder := httptest.NewRecorder()
	saveReq, _ := http.NewRequest("GET", "/", nil)
	assert.NoError(t, proxy.sessionStore.Save(saveRecorder, saveReq, session))
	cookie := saveRecorder.Header().Values("Set-Cookie")[0]

	// Replay the cookie on a fresh request. The assertion below inspects the
	// request itself, relying on ServeHTTP to have added the header to it.
	req, _ := http.NewRequest("GET", "/", nil)
	req.Header.Set("Cookie", cookie)
	proxy.ServeHTTP(httptest.NewRecorder(), req)

	assert.Equal(t, []string{"a,b"}, req.Header["X-Forwarded-Groups"])
}
// PassAccessTokenTest bundles the fake provider/upstream server, the proxy
// under test, and the options the proxy was built from.
type PassAccessTokenTest struct {
providerServer *httptest.Server
proxy *OAuthProxy
opts *options.Options
}
// PassAccessTokenTestOptions configures NewPassAccessTokenTest.
type PassAccessTokenTestOptions struct {
// PassAccessToken, when true, injects the "access_token" claim into the
// X-Forwarded-Access-Token request header.
PassAccessToken bool
// ValidToken is the value the TestProvider returns from ValidateSession.
ValidToken bool
// ProxyUpstream is appended as an extra upstream when its ID is non-empty.
ProxyUpstream options.Upstream
}
// NewPassAccessTokenTest starts a fake server and builds an OAuthProxy that
// uses it both as its provider and as its "/" upstream. For any path other
// than /oauth/token, the server writes the X-Forwarded-Access-Token request
// header (or "No access token found.") as the response body, so tests can
// observe what the proxy forwarded.
//
// The caller must Close the returned test. On error the server is closed
// before returning, so it does not leak.
func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTest, error) {
	patt := &PassAccessTokenTest{}
	patt.providerServer = httptest.NewServer(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			var payload string
			switch r.URL.Path {
			case "/oauth/token":
				payload = `{"access_token": "my_auth_token"}`
			default:
				payload = r.Header.Get("X-Forwarded-Access-Token")
				if payload == "" {
					payload = "No access token found."
				}
			}
			w.WriteHeader(200)
			_, err := w.Write([]byte(payload))
			if err != nil {
				panic(err)
			}
		}))
	// fail shuts the server down so an error return does not leak it.
	fail := func(err error) (*PassAccessTokenTest, error) {
		patt.providerServer.Close()
		return nil, err
	}

	patt.opts = baseTestOptions()
	patt.opts.UpstreamServers = options.UpstreamConfig{
		Upstreams: []options.Upstream{
			{
				ID:   patt.providerServer.URL,
				Path: "/",
				URI:  patt.providerServer.URL,
			},
		},
	}
	if opts.ProxyUpstream.ID != "" {
		patt.opts.UpstreamServers.Upstreams = append(patt.opts.UpstreamServers.Upstreams, opts.ProxyUpstream)
	}
	patt.opts.Cookie.Secure = false
	if opts.PassAccessToken {
		patt.opts.InjectRequestHeaders = []options.Header{
			{
				Name: "X-Forwarded-Access-Token",
				Values: []options.HeaderValue{
					{
						ClaimSource: &options.ClaimSource{
							Claim: "access_token",
						},
					},
				},
			},
		}
	}

	if err := validation.Validate(patt.opts); err != nil {
		return fail(err)
	}
	providerURL, err := url.Parse(patt.providerServer.URL)
	if err != nil {
		return fail(err)
	}
	const emailAddress = "michael.bland@gsa.gov"
	testProvider := NewTestProvider(providerURL, emailAddress)
	testProvider.ValidToken = opts.ValidToken
	patt.opts.SetProvider(testProvider)
	patt.proxy, err = NewOAuthProxy(patt.opts, func(email string) bool {
		return email == emailAddress
	})
	if err != nil {
		return fail(err)
	}
	return patt, nil
}
// Close shuts down the fake server started by NewPassAccessTokenTest.
func (patTest *PassAccessTokenTest) Close() {
	server := patTest.providerServer
	server.Close()
}
// getCallbackEndpoint simulates the provider redirecting back to the proxy's
// callback with a fixed code and a valid CSRF state. It returns the response
// status and the second Set-Cookie value, or "" when fewer than two cookies
// were set. NOTE(review): index 1 presumably skips the cleared CSRF cookie.
// Confirm. Panics if the CSRF helper fails.
func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int, cookie string) {
	csrf, err := cookies.NewCSRF(patTest.proxy.CookieOptions)
	if err != nil {
		panic(err)
	}

	callbackURL := fmt.Sprintf(
		"/oauth2/callback?code=callback_code&state=%s",
		encodeState(csrf.HashOAuthState(), "%2F"),
	)
	req, err := http.NewRequest(http.MethodGet, callbackURL, strings.NewReader(""))
	if err != nil {
		return 0, ""
	}

	// The recorder passed here is a throwaway; only the returned CSRF cookie
	// is needed so it can be attached to the callback request.
	csrfCookie, err := csrf.SetCookie(httptest.NewRecorder(), req)
	if err != nil {
		panic(err)
	}
	req.AddCookie(csrfCookie)

	recorder := httptest.NewRecorder()
	patTest.proxy.ServeHTTP(recorder, req)
	if setCookies := recorder.Header().Values("Set-Cookie"); len(setCookies) >= 2 {
		cookie = setCookies[1]
	}
	return recorder.Code, cookie
}
// getEndpointWithCookie makes a request against the proxy for endpoint. It
// carries the proxy's session cookie, whose value is extracted from cookie
// (a Set-Cookie header value). It returns the response status code and body.
// It returns (0, "") when cookie does not contain the session cookie or when
// the request cannot be built.
func (patTest *PassAccessTokenTest) getEndpointWithCookie(cookie string, endpoint string) (httpCode int, accessToken string) {
	cookieName := patTest.proxy.CookieOptions.Name
	keyPrefix := cookieName + "="

	// Take the value of the first "name=value" attribute that matches the
	// session cookie's name.
	var value string
	for _, field := range strings.Split(cookie, "; ") {
		if v := strings.TrimPrefix(field, keyPrefix); v != field {
			value = v
			break
		}
	}
	if value == "" {
		return 0, ""
	}

	req, err := http.NewRequest("GET", endpoint, strings.NewReader(""))
	if err != nil {
		return 0, ""
	}
	// AddCookie writes only name=value into the Cookie header. The remaining
	// attributes just describe the cookie, with a 24-hour lifetime.
	req.AddCookie(&http.Cookie{
		Name:     cookieName,
		Value:    value,
		Path:     "/",
		Expires:  time.Now().Add(24 * time.Hour),
		HttpOnly: true,
	})

	rw := httptest.NewRecorder()
	patTest.proxy.ServeHTTP(rw, req)
	return rw.Code, rw.Body.String()
}
// TestForwardAccessTokenUpstream checks that, when access-token injection is
// enabled, the token redeemed at the OAuth callback is forwarded upstream in
// X-Forwarded-Access-Token. The fake upstream echoes that header back.
func TestForwardAccessTokenUpstream(t *testing.T) {
	patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{
		PassAccessToken: true,
		ValidToken:      true,
	})
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(patTest.Close)

	// A successful validation will redirect and set the auth cookie.
	code, cookie := patTest.getCallbackEndpoint()
	if code != 302 {
		t.Fatalf("expected 302; got %d", code)
	}
	// cookie is a plain string. An empty value means no session cookie was set.
	assert.NotEmpty(t, cookie)

	// Now we make a regular request. The access_token from the cookie is
	// forwarded as the "X-Forwarded-Access-Token" header, which the test
	// provider server reads and writes into the response body.
	code, payload := patTest.getEndpointWithCookie(cookie, "/")
	if code != 200 {
		t.Fatalf("expected 200; got %d", code)
	}
	assert.Equal(t, "my_auth_token", payload)
}
// TestStaticProxyUpstream checks that an authenticated request routed to a
// static upstream at /static-proxy returns 200 with the body "Authenticated".
func TestStaticProxyUpstream(t *testing.T) {
	patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{
		PassAccessToken: true,
		ValidToken:      true,
		ProxyUpstream: options.Upstream{
			ID:     "static-proxy",
			Path:   "/static-proxy",
			Static: true,
		},
	})
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(patTest.Close)

	// A successful validation will redirect and set the auth cookie.
	code, cookie := patTest.getCallbackEndpoint()
	if code != 302 {
		t.Fatalf("expected 302; got %d", code)
	}
	// cookie is a string, so comparing it with nil can never fail; require a
	// non-empty session cookie instead.
	assert.NotEmpty(t, cookie)

	// Now make a regular request against the static upstream and validate
	// the status code and body returned through it.
	code, payload := patTest.getEndpointWithCookie(cookie, "/static-proxy")
	if code != 200 {
		t.Fatalf("expected 200; got %d", code)
	}
	assert.Equal(t, "Authenticated", payload)
}
// TestDoNotForwardAccessTokenUpstream checks that, without access-token
// injection, the upstream does not receive X-Forwarded-Access-Token.
func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
	patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{
		PassAccessToken: false,
		ValidToken:      true,
	})
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(patTest.Close)

	// A successful validation will redirect and set the auth cookie.
	code, cookie := patTest.getCallbackEndpoint()
	if code != 302 {
		t.Fatalf("expected 302; got %d", code)
	}
	// cookie is a string, so comparing it with nil can never fail; require a
	// non-empty session cookie instead.
	assert.NotEmpty(t, cookie)

	// Now we make a regular request, but the access token header should
	// not be present.
	code, payload := patTest.getEndpointWithCookie(cookie, "/")
	if code != 200 {
		t.Fatalf("expected 200; got %d", code)
	}
	assert.Equal(t, "No access token found.", payload)
}
func TestSessionValidationFailure(t *testing.T) {
patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{
ValidToken: false,
})
require.NoError(t, err)
t.Cleanup(patTest.Close)
// An unsuccessful validation will return 403 and not set the auth cookie.
code, cookie := patTest.getCallbackEndpoint()
assert.Equal(t, http.StatusForbidden, code)
assert.Equal(t, "", cookie)
}
// SignInPageTest holds a proxy built for sign-in page tests, together with
// the regular expressions used to inspect rendered responses.
type SignInPageTest struct {
opts *options.Options
proxy *OAuthProxy
// signInRegexp is compiled from signInRedirectPattern.
signInRegexp *regexp.Regexp
// signInProviderRegexp is compiled from signInSkipProvider.
signInProviderRegexp *regexp.Regexp
}
// signInRedirectPattern captures the value of the sign-in form's hidden "rd"
// input, which the tests compare against the expected redirect target.
const signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">`
// signInSkipProvider matches ">Found<", presumably from the body net/http
// writes for a 302 redirect. NOTE(review): confirm.
const signInSkipProvider = `>Found<`
// NewSignInPageTest builds a proxy from the base test options, with
// SkipProviderButton set to skipProvider and an email validator that accepts
// everyone. It also compiles the regular expressions used to inspect pages.
func NewSignInPageTest(skipProvider bool) (*SignInPageTest, error) {
	opts := baseTestOptions()
	opts.SkipProviderButton = skipProvider
	if err := validation.Validate(opts); err != nil {
		return nil, err
	}

	proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
	if err != nil {
		return nil, err
	}

	return &SignInPageTest{
		opts:                 opts,
		proxy:                proxy,
		signInRegexp:         regexp.MustCompile(signInRedirectPattern),
		signInProviderRegexp: regexp.MustCompile(signInSkipProvider),
	}, nil
}
// GetEndpoint issues a GET for endpoint against the proxy and returns the
// response status code and body.
func (sipTest *SignInPageTest) GetEndpoint(endpoint string) (int, string) {
	req, _ := http.NewRequest("GET", endpoint, strings.NewReader(""))
	recorder := httptest.NewRecorder()
	sipTest.proxy.ServeHTTP(recorder, req)
	return recorder.Code, recorder.Body.String()
}
// AlwaysSuccessfulValidator is a basic-auth validator stub that accepts any
// username/password pair.
type AlwaysSuccessfulValidator struct{}

// Validate accepts every credential pair.
func (AlwaysSuccessfulValidator) Validate(_, _ string) bool {
	return true
}
// TestManualSignInStoresUserGroupsInTheSession posts credentials to the
// sign-in endpoint with a validator that accepts anything. It then checks
// that the resulting session carries the configured HtpasswdUserGroups.
func TestManualSignInStoresUserGroupsInTheSession(t *testing.T) {
	userGroups := []string{"somegroup", "someothergroup"}

	opts := baseTestOptions()
	opts.HtpasswdUserGroups = userGroups
	if err := validation.Validate(opts); err != nil {
		t.Fatal(err)
	}
	proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
	if err != nil {
		t.Fatal(err)
	}
	// Accept any credentials so the sign-in POST succeeds.
	proxy.basicAuthValidator = AlwaysSuccessfulValidator{}

	form := url.Values{}
	form.Set("username", "someuser")
	form.Set("password", "somepass")
	encodedForm := form.Encode()

	signInReq, _ := http.NewRequest(http.MethodPost, "/oauth2/sign_in", strings.NewReader(encodedForm))
	signInReq.Header.Add("Content-Type", "application/x-www-form-urlencoded")
	signInResp := httptest.NewRecorder()
	proxy.ServeHTTP(signInResp, signInReq)
	assert.Equal(t, http.StatusFound, signInResp.Code)

	// Replay the cookies from the sign-in response, then load the session.
	req, _ := http.NewRequest(http.MethodGet, "/something", strings.NewReader(encodedForm))
	for _, c := range signInResp.Result().Cookies() {
		req.AddCookie(c)
	}
	s, err := proxy.sessionStore.Load(req)
	if err != nil {
		t.Fatal(err)
	}
	assert.Equal(t, userGroups, s.Groups)
}
// TestSignInPageIncludesTargetRedirect checks that an unauthenticated request
// gets a 403 sign-in page whose "rd" field holds the originally requested
// path.
func TestSignInPageIncludesTargetRedirect(t *testing.T) {
	sip, err := NewSignInPageTest(false)
	if err != nil {
		t.Fatal(err)
	}

	const endpoint = "/some/random/endpoint"
	code, body := sip.GetEndpoint(endpoint)
	assert.Equal(t, 403, code)

	match := sip.signInRegexp.FindStringSubmatch(body)
	switch {
	case match == nil:
		t.Fatalf("Did not find pattern in body: %s\nBody:\n%s", signInRedirectPattern, body)
	case match[1] != endpoint:
		t.Fatalf(`expected redirect to "%s", but was "%s"`, endpoint, match[1])
	}
}
// TestSignInPageDirectAccessRedirectsToRoot checks that opening the sign-in
// page directly returns 200, with an "rd" field pointing at "/".
func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
	sip, err := NewSignInPageTest(false)
	if err != nil {
		t.Fatal(err)
	}

	code, body := sip.GetEndpoint("/oauth2/sign_in")
	assert.Equal(t, 200, code)

	match := sip.signInRegexp.FindStringSubmatch(body)
	switch {
	case match == nil:
		t.Fatalf("Did not find pattern in body: %s\nBody:\n%s", signInRedirectPattern, body)
	case match[1] != "/":
		t.Fatalf(`expected redirect to "/", but was "%s"`, match[1])
	}
}
// TestSignInPageSkipProvider checks that, with the provider button skipped,
// an unauthenticated request is answered with a 302 whose body contains
// ">Found<".
func TestSignInPageSkipProvider(t *testing.T) {
	sip, err := NewSignInPageTest(true)
	if err != nil {
		t.Fatal(err)
	}

	code, body := sip.GetEndpoint("/some/random/endpoint")
	assert.Equal(t, 302, code)

	if sip.signInProviderRegexp.FindStringSubmatch(body) == nil {
		t.Fatalf("Did not find pattern in body: %s\nBody:\n%s", signInSkipProvider, body)
	}
}
// TestSignInPageSkipProviderDirect is the /sign_in variant of
// TestSignInPageSkipProvider: with the provider button skipped it expects a
// 302 whose body contains ">Found<".
func TestSignInPageSkipProviderDirect(t *testing.T) {
	sip, err := NewSignInPageTest(true)
	if err != nil {
		t.Fatal(err)
	}

	code, body := sip.GetEndpoint("/sign_in")
	assert.Equal(t, 302, code)

	if sip.signInProviderRegexp.FindStringSubmatch(body) == nil {
		t.Fatalf("Did not find pattern in body: %s\nBody:\n%s", signInSkipProvider, body)
	}
}
// ProcessCookieTest holds a proxy plus a reusable recorder/request pair for
// exercising session save and load through cookies.
type ProcessCookieTest struct {
opts *options.Options
proxy *OAuthProxy
rw *httptest.ResponseRecorder
req *http.Request
// validateUser is what the proxy's email validator (built in
// NewProcessCookieTest) returns; tests may change it after construction.
validateUser bool
}
// ProcessCookieTestOpts configures NewProcessCookieTest.
type ProcessCookieTestOpts struct {
// providerValidateCookieResponse becomes the TestProvider's ValidToken.
providerValidateCookieResponse bool
}
// OptionsModifier mutates test options before they are validated.
type OptionsModifier func(*options.Options)
// NewProcessCookieTest builds a ProcessCookieTest. It applies modifiers to
// the base options, enables cookie refresh for validation and construction,
// swaps in a TestProvider whose ValidToken comes from
// opts.providerValidateCookieResponse, and then disables refresh again.
func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifier) (*ProcessCookieTest, error) {
	pcTest := &ProcessCookieTest{opts: baseTestOptions()}
	for _, modify := range modifiers {
		modify(pcTest.opts)
	}

	// Enable refresh before validation. Per the original setup notes this
	// ensures the cipher used to encrypt the access_token is created.
	// NOTE(review): confirm this still applies.
	pcTest.opts.Cookie.Refresh = time.Hour
	if err := validation.Validate(pcTest.opts); err != nil {
		return nil, err
	}

	proxy, err := NewOAuthProxy(pcTest.opts, func(string) bool {
		// Read at call time so tests can change validateUser after setup.
		return pcTest.validateUser
	})
	if err != nil {
		return nil, err
	}
	pcTest.proxy = proxy

	provider := &TestProvider{
		ProviderData: &providers.ProviderData{},
		ValidToken:   opts.providerValidateCookieResponse,
	}
	provider.SetAllowedGroups(pcTest.opts.Providers[0].AllowedGroups)
	pcTest.proxy.provider = provider

	// Zero out refresh for the cases that don't involve access_token
	// validation; individual tests re-enable it when needed.
	pcTest.proxy.CookieOptions.Refresh = time.Duration(0)

	pcTest.rw = httptest.NewRecorder()
	pcTest.req, _ = http.NewRequest("GET", "/", strings.NewReader(""))
	pcTest.validateUser = true
	return pcTest, nil
}
// NewProcessCookieTestWithDefaults builds a ProcessCookieTest whose provider
// accepts sessions, with no option modifiers applied.
func NewProcessCookieTestWithDefaults() (*ProcessCookieTest, error) {
	return NewProcessCookieTestWithOptionsModifiers()
}

// NewProcessCookieTestWithOptionsModifiers builds a ProcessCookieTest whose
// provider accepts sessions. Each modifier is applied to the options before
// validation.
func NewProcessCookieTestWithOptionsModifiers(modifiers ...OptionsModifier) (*ProcessCookieTest, error) {
	accepting := ProcessCookieTestOpts{providerValidateCookieResponse: true}
	return NewProcessCookieTest(accepting, modifiers...)
}
// SaveSession saves s through the proxy using the test's recorder and
// request. It then copies every cookie the save produced onto the request,
// so a later LoadCookiedSession sees them.
func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState) error {
	if err := p.proxy.SaveSession(p.rw, p.req, s); err != nil {
		return err
	}
	for _, c := range p.rw.Result().Cookies() {
		p.req.AddCookie(c)
	}
	return nil
}

// LoadCookiedSession loads the session from the cookies on the test request.
func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, error) {
	req := p.req
	return p.proxy.LoadCookiedSession(req)
}
// TestLoadCookiedSession round-trips a session through the cookie store and
// checks that Email and AccessToken survive and that User stays empty.
func TestLoadCookiedSession(t *testing.T) {
	pcTest, err := NewProcessCookieTestWithDefaults()
	if err != nil {
		t.Fatal(err)
	}

	created := time.Now()
	want := &sessions.SessionState{
		Email:       "john.doe@example.com",
		AccessToken: "my_access_token",
		CreatedAt:   &created,
	}
	assert.NoError(t, pcTest.SaveSession(want))

	got, err := pcTest.LoadCookiedSession()
	if err != nil {
		t.Fatal(err)
	}
	assert.Equal(t, want.Email, got.Email)
	assert.Equal(t, "", got.User)
	assert.Equal(t, want.AccessToken, got.AccessToken)
}
// TestProcessCookieNoCookieError checks that loading without any session
// cookie returns an error and a nil session.
func TestProcessCookieNoCookieError(t *testing.T) {
	pcTest, err := NewProcessCookieTestWithDefaults()
	if err != nil {
		t.Fatal(err)
	}

	loaded, loadErr := pcTest.LoadCookiedSession()
	// The string argument is only a failure message; the error text itself
	// is not compared.
	assert.Error(t, loadErr, "cookie \"_oauth2_proxy\" not present")
	if loaded != nil {
		t.Errorf("expected nil session. got %#v", loaded)
	}
}
// TestProcessCookieRefreshNotSet saves a session created two hours ago, with
// a 23h cookie expiry and refresh disabled by NewProcessCookieTest. It checks
// that the session loads back, reports an age of at least two hours, and
// keeps its email.
func TestProcessCookieRefreshNotSet(t *testing.T) {
	pcTest, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) {
		opts.Cookie.Expire = time.Duration(23) * time.Hour
	})
	if err != nil {
		t.Fatal(err)
	}
	reference := time.Now().Add(time.Duration(-2) * time.Hour)
	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
	err = pcTest.SaveSession(startSession)
	assert.NoError(t, err)

	session, err := pcTest.LoadCookiedSession()
	// Stop on error: session would be nil and the checks below would panic.
	if err != nil {
		t.Fatal(err)
	}
	// The session was created two hours ago, so its age must be at least 2h.
	// Comparing against a negative duration would make this check vacuous.
	if session.Age() < time.Duration(2)*time.Hour {
		t.Errorf("cookie too young %v", session.Age())
	}
	assert.Equal(t, startSession.Email, session.Email)
}
// TestProcessCookieFailIfCookieExpired checks that a session created 25h ago
// fails to load when the cookie expiry is 24h.
func TestProcessCookieFailIfCookieExpired(t *testing.T) {
	const expire = 24 * time.Hour
	pcTest, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) {
		opts.Cookie.Expire = expire
	})
	if err != nil {
		t.Fatal(err)
	}

	// One hour past the expiry.
	reference := time.Now().Add(-(expire + time.Hour))
	startSession := &sessions.SessionState{
		Email:       "michael.bland@gsa.gov",
		AccessToken: "my_access_token",
		CreatedAt:   &reference,
	}
	assert.NoError(t, pcTest.SaveSession(startSession))

	session, err := pcTest.LoadCookiedSession()
	assert.Error(t, err)
	if session != nil {
		t.Errorf("expected nil session %#v", session)
	}
}
// TestProcessCookieFailIfRefreshSetAndCookieExpired checks that an expired
// session (created 25h ago with a 24h expiry) still fails to load even when
// cookie refresh is enabled.
func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
	const expire = 24 * time.Hour
	pcTest, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) {
		opts.Cookie.Expire = expire
	})
	if err != nil {
		t.Fatal(err)
	}

	// One hour past the expiry.
	reference := time.Now().Add(-(expire + time.Hour))
	startSession := &sessions.SessionState{
		Email:       "michael.bland@gsa.gov",
		AccessToken: "my_access_token",
		CreatedAt:   &reference,
	}
	assert.NoError(t, pcTest.SaveSession(startSession))

	// Turn refresh on only after the session has been saved.
	pcTest.proxy.CookieOptions.Refresh = time.Hour
	session, err := pcTest.LoadCookiedSession()
	assert.Error(t, err)
	if session != nil {
		t.Errorf("expected nil session %#v", session)
	}
}
// NewUserInfoEndpointTest builds a default ProcessCookieTest whose request
// targets the proxy's /userinfo endpoint.
func NewUserInfoEndpointTest() (*ProcessCookieTest, error) {
	pcTest, err := NewProcessCookieTestWithDefaults()
	if err != nil {
		return nil, err
	}
	userInfoURL := pcTest.opts.ProxyPrefix + "/userinfo"
	pcTest.req, _ = http.NewRequest("GET", userInfoURL, nil)
	return pcTest, nil
}
// TestUserInfoEndpointAccepted checks the JSON that /userinfo returns for
// several session shapes. groups and preferredUsername appear only when set,
// and the access token is never included.
func TestUserInfoEndpointAccepted(t *testing.T) {
	const withGroups = `{"user":"john.doe","email":"john.doe@example.com","groups":["example","groups"]}` + "\n"

	cases := []struct {
		name             string
		session          *sessions.SessionState
		expectedResponse string
	}{
		{
			name: "Full session",
			session: &sessions.SessionState{
				User:        "john.doe",
				Email:       "john.doe@example.com",
				Groups:      []string{"example", "groups"},
				AccessToken: "my_access_token",
			},
			expectedResponse: withGroups,
		},
		{
			name: "Minimal session",
			session: &sessions.SessionState{
				User:   "john.doe",
				Email:  "john.doe@example.com",
				Groups: []string{"example", "groups"},
			},
			expectedResponse: withGroups,
		},
		{
			name: "No groups",
			session: &sessions.SessionState{
				User:        "john.doe",
				Email:       "john.doe@example.com",
				AccessToken: "my_access_token",
			},
			expectedResponse: `{"user":"john.doe","email":"john.doe@example.com"}` + "\n",
		},
		{
			name: "With Preferred Username",
			session: &sessions.SessionState{
				User:              "john.doe",
				PreferredUsername: "john",
				Email:             "john.doe@example.com",
				Groups:            []string{"example", "groups"},
				AccessToken:       "my_access_token",
			},
			expectedResponse: `{"user":"john.doe","email":"john.doe@example.com","groups":["example","groups"],"preferredUsername":"john"}` + "\n",
		},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			test, err := NewUserInfoEndpointTest()
			if err != nil {
				t.Fatal(err)
			}
			assert.NoError(t, test.SaveSession(tc.session))

			test.proxy.ServeHTTP(test.rw, test.req)
			assert.Equal(t, http.StatusOK, test.rw.Code)
			bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
			assert.Equal(t, tc.expectedResponse, string(bodyBytes))
		})
	}
}
// TestUserInfoEndpointUnauthorizedOnNoCookieSetError checks that /userinfo
// answers 401 when the request carries no session cookie.
func TestUserInfoEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
	userInfo, err := NewUserInfoEndpointTest()
	if err != nil {
		t.Fatal(err)
	}
	userInfo.proxy.ServeHTTP(userInfo.rw, userInfo.req)
	assert.Equal(t, http.StatusUnauthorized, userInfo.rw.Code)
}
// TestEncodedUrlsStayEncoded requests, while unauthenticated, a path that
// contains percent-encoded slashes, and expects a 403 status.
func TestEncodedUrlsStayEncoded(t *testing.T) {
	sip, err := NewSignInPageTest(false)
	if err != nil {
		t.Fatal(err)
	}
	status, _ := sip.GetEndpoint("/%2F/test1/%2F/test2")
	assert.Equal(t, 403, status)
}
// NewAuthOnlyEndpointTest builds a ProcessCookieTest, applying modifiers to
// its options. Its request targets the proxy's /auth endpoint, with
// querystring appended verbatim (include the leading "?" when non-empty).
func NewAuthOnlyEndpointTest(querystring string, modifiers ...OptionsModifier) (*ProcessCookieTest, error) {
	pcTest, err := NewProcessCookieTestWithOptionsModifiers(modifiers...)
	if err != nil {
		return nil, err
	}
	authURL := fmt.Sprintf("%s/auth%s", pcTest.opts.ProxyPrefix, querystring)
	pcTest.req, _ = http.NewRequest("GET", authURL, nil)
	return pcTest, nil
}
// TestAuthOnlyEndpointAccepted checks that /auth answers 202 Accepted with an
// empty body for a fresh, valid session.
func TestAuthOnlyEndpointAccepted(t *testing.T) {
	authTest, err := NewAuthOnlyEndpointTest("")
	if err != nil {
		t.Fatal(err)
	}

	now := time.Now()
	assert.NoError(t, authTest.SaveSession(&sessions.SessionState{
		Email:       "michael.bland@gsa.gov",
		AccessToken: "my_access_token",
		CreatedAt:   &now,
	}))

	authTest.proxy.ServeHTTP(authTest.rw, authTest.req)
	assert.Equal(t, http.StatusAccepted, authTest.rw.Code)
	body, _ := ioutil.ReadAll(authTest.rw.Body)
	assert.Equal(t, "", string(body))
}
// TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError checks that /auth
// answers 401 with the body "Unauthorized\n" when no session cookie is sent.
func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
	authTest, err := NewAuthOnlyEndpointTest("")
	if err != nil {
		t.Fatal(err)
	}

	authTest.proxy.ServeHTTP(authTest.rw, authTest.req)
	assert.Equal(t, http.StatusUnauthorized, authTest.rw.Code)
	body, _ := ioutil.ReadAll(authTest.rw.Body)
	assert.Equal(t, "Unauthorized\n", string(body))
}
// TestAuthOnlyEndpointUnauthorizedOnExpiration checks that /auth answers 401
// for a session created 25h ago when the cookie expiry is 24h.
func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
	const expire = 24 * time.Hour
	authTest, err := NewAuthOnlyEndpointTest("", func(opts *options.Options) {
		opts.Cookie.Expire = expire
	})
	if err != nil {
		t.Fatal(err)
	}

	// One hour past the expiry.
	reference := time.Now().Add(-(expire + time.Hour))
	assert.NoError(t, authTest.SaveSession(&sessions.SessionState{
		Email:       "michael.bland@gsa.gov",
		AccessToken: "my_access_token",
		CreatedAt:   &reference,
	}))

	authTest.proxy.ServeHTTP(authTest.rw, authTest.req)
	assert.Equal(t, http.StatusUnauthorized, authTest.rw.Code)
	body, _ := ioutil.ReadAll(authTest.rw.Body)
	assert.Equal(t, "Unauthorized\n", string(body))
}
// TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure checks that /auth
// answers 401 when the email validator rejects an otherwise valid session.
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
	authTest, err := NewAuthOnlyEndpointTest("")
	if err != nil {
		t.Fatal(err)
	}

	now := time.Now()
	assert.NoError(t, authTest.SaveSession(&sessions.SessionState{
		Email:       "michael.bland@gsa.gov",
		AccessToken: "my_access_token",
		CreatedAt:   &now,
	}))

	// The proxy's validator reads this flag at call time, so flipping it
	// here makes the next request fail email validation.
	authTest.validateUser = false
	authTest.proxy.ServeHTTP(authTest.rw, authTest.req)
	assert.Equal(t, http.StatusUnauthorized, authTest.rw.Code)
	body, _ := ioutil.ReadAll(authTest.rw.Body)
	assert.Equal(t, "Unauthorized\n", string(body))
}
func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
var pcTest ProcessCookieTest
pcTest.opts = baseTestOptions()
pcTest.opts.InjectResponseHeaders = []options.Header{
{
Name: "X-Auth-Request-User",
Values: []options.HeaderValue{
{
ClaimSource: &options.ClaimSource{
Claim: "user",
},
},
},
},
{
Name: "X-Auth-Request-Email",
Values: []options.HeaderValue{
{
ClaimSource: &options.ClaimSource{
Claim: "email",
},
},
},
},
{
Name: "X-Auth-Request-Groups",
Values: []options.HeaderValue{
{
ClaimSource: &options.ClaimSource{
Claim: "groups",
},
},
},
},
{
Name: "X-Forwarded-Preferred-Username",
Values: []options.HeaderValue{
{
ClaimSource: &options.ClaimSource{
Claim: "preferred_username",
},
},
},
},
}
pcTest.opts.Providers[0].AllowedGroups = []string{"oauth_groups"}
err := validation.Validate(pcTest.opts)
assert.NoError(t, err)
pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
return pcTest.validateUser
})
if err != nil {
t.Fatal(err)
}
pcTest.proxy.provider = &TestProvider{
ProviderData: &providers.ProviderData{},
ValidToken: true,
}
pcTest.validateUser = true
pcTest.rw = httptest.NewRecorder()
pcTest.req, _ = http.NewRequest("GET",
pcTest.opts.ProxyPrefix+"/auth", nil)
created := time.Now()
startSession := &sessions.SessionState{
User: "oauth_user", Groups: []string{"oauth_groups"}, Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
err = pcTest.SaveSession(startSession)
assert.NoError(t, err)
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
assert.Equal(t, "oauth_user", pcTest.rw.Header().Get("X-Auth-Request-User"))
assert.Equal(t, startSession.Groups, pcTest.rw.Header().Values("X-Auth-Request-Groups"))
assert.Equal(t, "oauth_user@example.com", pcTest.rw.Header().Get("X-Auth-Request-Email"))
}
func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) {
var pcTest ProcessCookieTest
pcTest.opts = baseTestOptions()
pcTest.opts.InjectResponseHeaders = []options.Header{
{
Name: "X-Auth-Request-User",
Values: []options.HeaderValue{
{
ClaimSource: &options.ClaimSource{
Claim: "user",
},
},
},
},
{
Name: "X-Auth-Request-Email",
Values: []options.HeaderValue{
{
ClaimSource: &options.ClaimSource{
Claim: "email",
},
},
},
},
{
Name: "X-Auth-Request-Groups",
Values: []options.HeaderValue{
{
ClaimSource: &options.ClaimSource{
Claim: "groups",
},
},
},
},
{
Name: "X-Forwarded-Preferred-Username",
Values: []options.HeaderValue{
{
ClaimSource: &options.ClaimSource{
Claim: "preferred_username",
},
},
},
},
{
Name: "Authorization",
Values: []options.HeaderValue{
{
ClaimSource: &options.ClaimSource{
Claim: "user",
BasicAuthPassword: &options.SecretSource{
Value: []byte("This is a secure password"),
},
},
},
},
},
}
err := validation.Validate(pcTest.opts)
assert.NoError(t, err)
pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
return pcTest.validateUser
})
if err != nil {
t.Fatal(err)
}
pcTest.proxy.provider = &TestProvider{
ProviderData: &providers.ProviderData{},
ValidToken: true,
}
pcTest.validateUser = true
pcTest.rw = httptest.NewRecorder()
pcTest.req, _ = http.NewRequest("GET",
pcTest.opts.ProxyPrefix+"/auth", nil)
created := time.Now()
startSession := &sessions.SessionState{
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
err = pcTest.SaveSession(startSession)
assert.NoError(t, err)
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
assert.Equal(t, "oauth_user", pcTest.rw.Header().Values("X-Auth-Request-User")[0])
assert.Equal(t, "oauth_user@example.com", pcTest.rw.Header().Values("X-Auth-Request-Email")[0])
expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte("oauth_user:This is a secure password"))
assert.Equal(t, expectedHeader, pcTest.rw.Header().Values("Authorization")[0])
}
func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) {
var pcTest ProcessCookieTest
pcTest.opts = baseTestOptions()
pcTest.opts.InjectResponseHeaders = []options.Header{
{
Name: "X-Auth-Request-User",
Values: []options.HeaderValue{
{
ClaimSource: &options.ClaimSource{
Claim: "user",
},
},
},
},
{
Name: "X-Auth-Request-Email",
Values: []options.HeaderValue{
{
ClaimSource: &options.ClaimSource{
Claim: "email",
},
},
},
},
{
Name: "X-Auth-Request-Groups",
Values: []options.HeaderValue{
{
ClaimSource: &options.ClaimSource{
Claim: "groups",
},
},
},
},
{
Name: "X-Forwarded-Preferred-Username",
Values: []options.HeaderValue{
{
ClaimSource: &options.ClaimSource{
Claim: "preferred_username",
},
},
},
},
}
err := validation.Validate(pcTest.opts)
assert.NoError(t, err)
pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
return pcTest.validateUser
})
if err != nil {
t.Fatal(err)
}
pcTest.proxy.provider = &TestProvider{
ProviderData: &providers.ProviderData{},
ValidToken: true,
}
pcTest.validateUser = true
pcTest.rw = httptest.NewRecorder()
pcTest.req, _ = http.NewRequest("GET",
pcTest.opts.ProxyPrefix+"/auth", nil)
created := time.Now()
startSession := &sessions.SessionState{
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
err = pcTest.SaveSession(startSession)
assert.NoError(t, err)
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
assert.Equal(t, "oauth_user", pcTest.rw.Header().Values("X-Auth-Request-User")[0])
assert.Equal(t, "oauth_user@example.com", pcTest.rw.Header().Values("X-Auth-Request-Email")[0])
assert.Equal(t, 0, len(pcTest.rw.Header().Values("Authorization")), "should not have Authorization header entries")
}
// TestAuthSkippedForPreflightRequests checks that, with SkipAuthPreflight
// enabled, an OPTIONS request with no session is proxied straight to the
// upstream.
func TestAuthSkippedForPreflightRequests(t *testing.T) {
	upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		// This handler runs on the server's goroutine, not the test's.
		// t.Fatal (FailNow) must not be called there; t.Error is safe.
		if _, err := w.Write([]byte("response")); err != nil {
			t.Error(err)
		}
	}))
	t.Cleanup(upstreamServer.Close)

	opts := baseTestOptions()
	opts.UpstreamServers = options.UpstreamConfig{
		Upstreams: []options.Upstream{
			{
				ID:   upstreamServer.URL,
				Path: "/",
				URI:  upstreamServer.URL,
			},
		},
	}
	opts.SkipAuthPreflight = true
	err := validation.Validate(opts)
	assert.NoError(t, err)

	upstreamURL, err := url.Parse(upstreamServer.URL)
	if err != nil {
		t.Fatal(err)
	}
	opts.SetProvider(NewTestProvider(upstreamURL, ""))
	proxy, err := NewOAuthProxy(opts, func(string) bool { return false })
	if err != nil {
		t.Fatal(err)
	}

	rw := httptest.NewRecorder()
	req, err := http.NewRequest("OPTIONS", "/preflight-request", nil)
	if err != nil {
		t.Fatal(err)
	}
	proxy.ServeHTTP(rw, req)

	assert.Equal(t, 200, rw.Code)
	assert.Equal(t, "response", rw.Body.String())
}
// SignatureAuthenticator is a test upstream handler. It checks the HMAC
// signature on each incoming request with auth and writes a textual verdict
// into the response body.
type SignatureAuthenticator struct {
	auth hmacauth.HmacAuth
}

// Authenticate verifies r's signature and writes one of "no signature
// received", "signatures match", or a mismatch report showing both the
// received and the computed signature. An unknown result value or a failed
// write panics, so the problem surfaces loudly in the test run.
func (v *SignatureAuthenticator) Authenticate(w http.ResponseWriter, r *http.Request) {
	outcome, received, computed := v.auth.AuthenticateRequest(r)

	verdict := func() string {
		switch outcome {
		case hmacauth.ResultNoSignature:
			return "no signature received"
		case hmacauth.ResultMatch:
			return "signatures match"
		case hmacauth.ResultMismatch:
			return fmt.Sprintf(
				"signatures do not match:\n received: %s\n computed: %s",
				received,
				computed)
		}
		panic("unknown result value: " + outcome.String())
	}()

	if _, err := w.Write([]byte(verdict)); err != nil {
		panic(err)
	}
}
// SignatureTest bundles the fixtures for exercising request signing: an
// upstream server that verifies HMAC signatures, a fake provider server, the
// proxy options, and the recorder the proxy writes its response to.
type SignatureTest struct {
	opts          *options.Options
	upstream      *httptest.Server
	upstreamHost  string
	provider      *httptest.Server
	header        http.Header
	rw            *httptest.ResponseRecorder
	authenticator *SignatureAuthenticator
}

// NewSignatureTest starts the upstream and provider test servers and returns
// a SignatureTest wired to them. The caller must call Close when finished. If
// an error occurs, any server already started is closed before returning, so
// nothing leaks.
func NewSignatureTest() (*SignatureTest, error) {
	opts := baseTestOptions()
	opts.EmailDomains = []string{"acm.org"}

	authenticator := &SignatureAuthenticator{}
	upstreamServer := httptest.NewServer(
		http.HandlerFunc(authenticator.Authenticate))
	upstreamURL, err := url.Parse(upstreamServer.URL)
	if err != nil {
		upstreamServer.Close()
		return nil, err
	}
	opts.UpstreamServers = options.UpstreamConfig{
		Upstreams: []options.Upstream{
			{
				ID:   upstreamServer.URL,
				Path: "/",
				URI:  upstreamServer.URL,
			},
		},
	}

	providerHandler := func(w http.ResponseWriter, r *http.Request) {
		_, err := w.Write([]byte(`{"access_token": "my_auth_token"}`))
		if err != nil {
			panic(err)
		}
	}
	provider := httptest.NewServer(http.HandlerFunc(providerHandler))
	providerURL, err := url.Parse(provider.URL)
	if err != nil {
		provider.Close()
		upstreamServer.Close()
		return nil, err
	}
	opts.SetProvider(NewTestProvider(providerURL, "mbland@acm.org"))

	// Keyed fields keep this literal correct if SignatureTest's fields are
	// ever reordered or extended.
	return &SignatureTest{
		opts:          opts,
		upstream:      upstreamServer,
		upstreamHost:  upstreamURL.Host,
		provider:      provider,
		header:        make(http.Header),
		rw:            httptest.NewRecorder(),
		authenticator: authenticator,
	}, nil
}

// Close shuts down the provider and upstream test servers.
func (st *SignatureTest) Close() {
	st.provider.Close()
	st.upstream.Close()
}
// fakeNetConn simulates an http.Request.Body buffer that will be consumed
// when it is read by the hmacauth.HmacAuth if not handled properly. See:
// https://github.com/18F/hmacauth/pull/4
//
// Reads drain reqBody, so the body can only be read once. After it is
// exhausted, every Read returns (0, io.EOF).
type fakeNetConn struct {
	reqBody string
}

// Read implements io.Reader. It copies at most len(p) bytes of the remaining
// body into p and returns io.EOF together with the final chunk.
func (fnc *fakeNetConn) Read(p []byte) (n int, err error) {
	if len(fnc.reqBody) == 0 {
		return 0, io.EOF
	}
	// copy reports how many bytes actually fit into p. Returning the full body
	// length instead would claim n > len(p), breaking the io.Reader contract
	// and silently dropping whatever did not fit.
	n = copy(p, fnc.reqBody)
	fnc.reqBody = fnc.reqBody[n:]
	if len(fnc.reqBody) == 0 {
		return n, io.EOF
	}
	return n, nil
}
// MakeRequestWithExpectedKey validates st.opts, builds a proxy from it, and
// sends a request to "/foo/bar" that carries a freshly saved session. Before
// serving, it configures the upstream authenticator to verify signatures with
// key. The proxy's response is recorded in st.rw. A non-empty body is wrapped
// in a fakeNetConn so that it can only be read once.
func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) error {
	if err := validation.Validate(st.opts); err != nil {
		return err
	}
	proxy, err := NewOAuthProxy(st.opts, func(string) bool { return true })
	if err != nil {
		return err
	}

	var reqBody io.ReadCloser
	if len(body) > 0 {
		reqBody = ioutil.NopCloser(&fakeNetConn{reqBody: body})
	}
	req := httptest.NewRequest(method, "/foo/bar", reqBody)
	req.Header = st.header

	session := &sessions.SessionState{Email: "mbland@acm.org", AccessToken: "my_access_token"}
	if err := proxy.SaveSession(st.rw, req, session); err != nil {
		return err
	}
	for _, cookie := range st.rw.Result().Cookies() {
		req.AddCookie(cookie)
	}

	// The upstream uses this to validate the signature the proxy attaches.
	st.authenticator.auth = hmacauth.NewHmacAuth(crypto.SHA1, []byte(key), upstream.SignatureHeader, upstream.SignatureHeaders)
	proxy.ServeHTTP(st.rw, req)
	return nil
}
// TestRequestSignature proxies requests with and without a configured
// signature key and checks the verdict the signature-verifying upstream
// writes back.
func TestRequestSignature(t *testing.T) {
	cases := []struct {
		name   string
		method string
		body   string
		key    string
		resp   string
	}{
		{name: "No request signature", method: "GET", resp: "no signature received"},
		{name: "Get request", method: "GET", key: "7d9e1aa87a5954e6f9fc59266b3af9d7c35fda2d", resp: "signatures match"},
		{name: "Post request", method: "POST", body: `{ "hello": "world!" }`, key: "d90df39e2d19282840252612dd7c81421a372f61", resp: "signatures match"},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			st, err := NewSignatureTest()
			if err != nil {
				t.Fatal(err)
			}
			t.Cleanup(st.Close)

			// Signing is only enabled when a key is supplied.
			if tc.key != "" {
				st.opts.SignatureKey = "sha1:" + tc.key
			}

			assert.NoError(t, st.MakeRequestWithExpectedKey(tc.method, tc.body, tc.key))
			assert.Equal(t, 200, st.rw.Code)
			assert.Equal(t, tc.resp, st.rw.Body.String())
		})
	}
}
// ajaxRequestTest holds a proxy used to check how requests without a session
// are answered, depending on their Accept header and the ForceJSONErrors
// option.
type ajaxRequestTest struct {
	opts  *options.Options
	proxy *OAuthProxy
}

// newAjaxRequestTest validates base test options (with ForceJSONErrors set as
// given) and builds a proxy whose email validator accepts every address.
func newAjaxRequestTest(forceJSONErrors bool) (*ajaxRequestTest, error) {
	test := &ajaxRequestTest{}
	test.opts = baseTestOptions()
	test.opts.ForceJSONErrors = forceJSONErrors
	err := validation.Validate(test.opts)
	if err != nil {
		return nil, err
	}
	test.proxy, err = NewOAuthProxy(test.opts, func(email string) bool {
		return true
	})
	if err != nil {
		return nil, err
	}
	return test, nil
}

// getEndpoint sends a GET for endpoint through the proxy and returns the
// recorded status code, response headers and body. When header is nil, the
// request keeps the empty (non-nil) header map http.NewRequest creates, so
// code that writes request headers cannot panic on a nil map.
func (test *ajaxRequestTest) getEndpoint(endpoint string, header http.Header) (int, http.Header, []byte, error) {
	rw := httptest.NewRecorder()
	req, err := http.NewRequest(http.MethodGet, endpoint, strings.NewReader(""))
	if err != nil {
		return 0, nil, nil, err
	}
	if header != nil {
		req.Header = header
	}
	test.proxy.ServeHTTP(rw, req)
	return rw.Code, rw.Header(), rw.Body.Bytes(), nil
}
// testAjaxUnauthorizedRequest requests "/test" without a session. It expects
// a 401 with Content-Type applicationJSON and the body "{}". The JSON answer
// is triggered either by header (an Accept header asking for JSON) or by
// forceJSONErrors.
func testAjaxUnauthorizedRequest(t *testing.T, header http.Header, forceJSONErrors bool) {
	test, err := newAjaxRequestTest(forceJSONErrors)
	if err != nil {
		t.Fatal(err)
	}

	code, respHeader, body, err := test.getEndpoint("/test", header)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusUnauthorized, code)
	assert.Equal(t, applicationJSON, respHeader.Get("Content-Type"))
	assert.Equal(t, []byte("{}"), body)
}

func TestAjaxUnauthorizedRequest1(t *testing.T) {
	// Set canonicalises the lower-case key to "Accept".
	header := http.Header{}
	header.Set("accept", applicationJSON)
	testAjaxUnauthorizedRequest(t, header, false)
}

func TestAjaxUnauthorizedRequest2(t *testing.T) {
	header := http.Header{}
	header.Set("Accept", applicationJSON)
	testAjaxUnauthorizedRequest(t, header, false)
}

func TestAjaxUnauthorizedRequestAccept1(t *testing.T) {
	header := http.Header{}
	header.Set("Accept", "application/json, text/plain, */*")
	testAjaxUnauthorizedRequest(t, header, false)
}

func TestForceJSONErrorsUnauthorizedRequest(t *testing.T) {
	// No Accept header at all: the JSON response comes from ForceJSONErrors.
	testAjaxUnauthorizedRequest(t, nil, true)
}
// TestAjaxForbiddendRequest checks that a request with no Accept header and
// ForceJSONErrors disabled gets a 403 whose Content-Type is not
// applicationJSON.
func TestAjaxForbiddendRequest(t *testing.T) {
	test, err := newAjaxRequestTest(false)
	if err != nil {
		t.Fatal(err)
	}

	code, respHeader, _, err := test.getEndpoint("/test", make(http.Header))
	assert.NoError(t, err)
	assert.Equal(t, http.StatusForbidden, code)
	assert.NotEqual(t, applicationJSON, respHeader.Get("Content-Type"))
}
// TestClearSplitCookie checks that clearing a session stored across split
// cookies ("oauth2_0", "oauth2_1") produces exactly two Set-Cookie headers,
// one per split part. The unrelated "test1" cookie gets none.
func TestClearSplitCookie(t *testing.T) {
	opts := baseTestOptions()
	opts.Cookie.Secret = base64CookieSecret
	opts.Cookie.Name = "oauth2"
	opts.Cookie.Domains = []string{"abc"}
	err := validation.Validate(opts)
	assert.NoError(t, err)

	store, err := sessionscookie.NewCookieSessionStore(&opts.Session, &opts.Cookie)
	if err != nil {
		t.Fatal(err)
	}
	p := OAuthProxy{CookieOptions: &opts.Cookie, sessionStore: store}

	var rw = httptest.NewRecorder()
	req := httptest.NewRequest("get", "/", nil)
	req.AddCookie(&http.Cookie{
		Name:  "test1",
		Value: "test1",
	})
	req.AddCookie(&http.Cookie{
		Name:  "oauth2_0",
		Value: "oauth2_0",
	})
	req.AddCookie(&http.Cookie{
		Name:  "oauth2_1",
		Value: "oauth2_1",
	})

	err = p.ClearSessionCookie(rw, req)
	assert.NoError(t, err)
	header := rw.Header()
	// The message previously claimed 3 entries while asserting 2.
	assert.Equal(t, 2, len(header["Set-Cookie"]), "should have 2 set-cookie header entries")
}
// TestClearSingleCookie checks that clearing an unsplit "oauth2" session
// cookie produces exactly one Set-Cookie header. The unrelated "test1" cookie
// gets none.
func TestClearSingleCookie(t *testing.T) {
	opts := baseTestOptions()
	opts.Cookie.Name = "oauth2"
	opts.Cookie.Domains = []string{"abc"}

	store, err := sessionscookie.NewCookieSessionStore(&opts.Session, &opts.Cookie)
	if err != nil {
		t.Fatal(err)
	}
	p := OAuthProxy{CookieOptions: &opts.Cookie, sessionStore: store}

	req := httptest.NewRequest("get", "/", nil)
	// Each cookie's value mirrors its name.
	for _, name := range []string{"test1", "oauth2"} {
		req.AddCookie(&http.Cookie{Name: name, Value: name})
	}

	rw := httptest.NewRecorder()
	err = p.ClearSessionCookie(rw, req)
	assert.NoError(t, err)
	assert.Equal(t, 1, len(rw.Header()["Set-Cookie"]), "should have 1 set-cookie header entries")
}
// NoOpKeySet is a key set for the oidc verifier that performs no signature
// verification. Tests can use it to accept JWTs signed with keys they do not
// hold.
type NoOpKeySet struct {
}

// VerifySignature skips signature checking and returns the base64url-decoded
// payload, i.e. the second dot-separated segment of jwt. A token with no
// payload segment yields an error instead of an index-out-of-range panic.
func (NoOpKeySet) VerifySignature(_ context.Context, jwt string) (payload []byte, err error) {
	parts := strings.Split(jwt, ".")
	if len(parts) < 2 {
		return nil, fmt.Errorf("malformed jwt: expected at least 2 segments, got %d", len(parts))
	}
	return base64.RawURLEncoding.DecodeString(parts[1])
}
// TestGetJwtSession checks that a bearer JWT accepted by an extra JWT bearer
// verifier yields a session whose claims are injected into the request
// headers (Authorization, X-Forwarded-User, X-Forwarded-Email) and into the
// /auth response headers (Authorization, X-Auth-Request-User,
// X-Auth-Request-Email).
func TestGetJwtSession(t *testing.T) {
	/* token payload:
	{
	  "sub": "1234567890",
	  "aud": "https://test.myapp.com",
	  "name": "John Doe",
	  "email": "john@example.com",
	  "iss": "https://issuer.example.com",
	  "iat": 1553691215,
	  "exp": 1912151821
	}
	*/
	goodJwt := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." +
		"eyJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjoiaHR0cHM6Ly90ZXN0Lm15YXBwLmNvbSIsIm5hbWUiOiJKb2huIERvZSIsImVtY" +
		"WlsIjoiam9obkBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwiaWF0IjoxNTUzNjkxMj" +
		"E1LCJleHAiOjE5MTIxNTE4MjF9." +
		"rLVyzOnEldUq_pNkfa-WiV8TVJYWyZCaM2Am_uo8FGg11zD7l-qmz3x1seTvqpH6Y0Ty00fmv6dJnGnC8WMnPXQiodRTfhBSe" +
		"OKZMu0HkMD2sg52zlKkbfLTO6ic5VnbVgwjjrB8am_Ta6w7kyFUaB5C1BsIrrLMldkWEhynbb8"

	// NoOpKeySet skips signature checks, so the test needs no signing key.
	keyset := NoOpKeySet{}
	verifier := oidc.NewVerifier("https://issuer.example.com", keyset,
		&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true,
			SkipClientIDCheck: true})
	verificationOptions := &internaloidc.IDTokenVerificationOptions{
		AudienceClaims: []string{"aud"},
		ClientID:       "https://test.myapp.com",
		ExtraAudiences: []string{},
	}
	internalVerifier := internaloidc.NewVerifier(verifier, verificationOptions)

	// claimHeader builds a header whose single value is read from claim,
	// optionally prefixed with a literal such as "Bearer ".
	claimHeader := func(name, claim, prefix string) options.Header {
		return options.Header{
			Name: name,
			Values: []options.HeaderValue{
				{ClaimSource: &options.ClaimSource{Claim: claim, Prefix: prefix}},
			},
		}
	}

	test, err := NewAuthOnlyEndpointTest("", func(opts *options.Options) {
		opts.InjectRequestHeaders = []options.Header{
			claimHeader("Authorization", "id_token", "Bearer "),
			claimHeader("X-Forwarded-User", "user", ""),
			claimHeader("X-Forwarded-Email", "email", ""),
		}
		opts.InjectResponseHeaders = []options.Header{
			claimHeader("Authorization", "id_token", "Bearer "),
			claimHeader("X-Auth-Request-User", "user", ""),
			claimHeader("X-Auth-Request-Email", "email", ""),
		}
		opts.SkipJwtBearerTokens = true
		opts.SetJWTBearerVerifiers(append(opts.GetJWTBearerVerifiers(), internalVerifier))
	})
	if err != nil {
		t.Fatal(err)
	}

	// Fail clearly rather than dereferencing a nil provider below.
	tp, ok := test.proxy.provider.(*TestProvider)
	if !ok {
		t.Fatalf("expected provider of type *TestProvider, got %T", test.proxy.provider)
	}
	tp.GroupValidator = func(s string) bool {
		return true
	}

	authHeader := fmt.Sprintf("Bearer %s", goodJwt)
	test.req.Header = map[string][]string{
		"Authorization": {authHeader},
	}
	test.proxy.ServeHTTP(test.rw, test.req)
	if test.rw.Code >= 400 {
		t.Fatalf("expected a non-error status, got %d", test.rw.Code)
	}

	// testify's Equal takes (expected, actual); keeping that order keeps the
	// failure messages accurate.

	// Check PassAuthorization, should overwrite Basic header
	assert.Equal(t, authHeader, test.req.Header.Get("Authorization"))
	assert.Equal(t, "1234567890", test.req.Header.Get("X-Forwarded-User"))
	assert.Equal(t, "john@example.com", test.req.Header.Get("X-Forwarded-Email"))

	// SetAuthorization and SetXAuthRequest
	assert.Equal(t, authHeader, test.rw.Header().Get("Authorization"))
	assert.Equal(t, "1234567890", test.rw.Header().Get("X-Auth-Request-User"))
	assert.Equal(t, "john@example.com", test.rw.Header().Get("X-Auth-Request-Email"))
}
// Test_prepareNoCache checks that prepareNoCache sets every entry of
// noCacheHeaders on the response.
func Test_prepareNoCache(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		prepareNoCache(w)
	})
	mux := http.NewServeMux()
	mux.Handle("/", handler)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	mux.ServeHTTP(rec, req)

	for k, v := range noCacheHeaders {
		// testify's Equal takes (expected, actual); the swapped order
		// produced inverted failure messages.
		assert.Equal(t, v, rec.Header().Get(k))
	}
}
// Test_noCacheHeaders checks that the proxy's own endpoints (except /auth)
// send a no-cache Cache-Control header, and that none of noCacheHeaders are
// added to responses proxied from the upstream.
func Test_noCacheHeaders(t *testing.T) {
	upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if _, err := w.Write([]byte("upstream")); err != nil {
			t.Error(err)
		}
	}))
	t.Cleanup(upstreamServer.Close)

	opts := baseTestOptions()
	opts.UpstreamServers = options.UpstreamConfig{
		Upstreams: []options.Upstream{
			{
				ID:   upstreamServer.URL,
				Path: "/",
				URI:  upstreamServer.URL,
			},
		},
	}
	opts.SkipAuthRegex = []string{".*"}
	err := validation.Validate(opts)
	assert.NoError(t, err)

	proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
	if err != nil {
		t.Fatal(err)
	}

	t.Run("not exist in response from upstream", func(t *testing.T) {
		rec := httptest.NewRecorder()
		proxy.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/upstream", nil))
		assert.Equal(t, http.StatusOK, rec.Code)
		assert.Equal(t, "upstream", rec.Body.String())

		// None of noCacheHeaders may appear in a response proxied from upstream.
		for name := range noCacheHeaders {
			assert.Equal(t, "", rec.Header().Get(name))
		}
	})

	t.Run("has no-cache", func(t *testing.T) {
		cases := []struct {
			path       string
			hasNoCache bool
		}{
			{path: "/oauth2/sign_in", hasNoCache: true},
			{path: "/oauth2/sign_out", hasNoCache: true},
			{path: "/oauth2/start", hasNoCache: true},
			{path: "/oauth2/callback", hasNoCache: true},
			{path: "/oauth2/auth", hasNoCache: false},
			{path: "/oauth2/userinfo", hasNoCache: true},
			{path: "/upstream", hasNoCache: false},
		}
		for _, tc := range cases {
			t.Run(tc.path, func(t *testing.T) {
				rec := httptest.NewRecorder()
				proxy.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, tc.path, nil))

				cacheControl := rec.Result().Header.Get("Cache-Control")
				if gotNoCache := strings.Contains(cacheControl, "no-cache"); gotNoCache != tc.hasNoCache {
					t.Errorf(`unexpected "Cache-Control" header: %s`, cacheControl)
				}
			})
		}
	})
}
// baseTestOptions returns a fresh options.Options for tests. It sets the test
// cookie secret, the first provider's ID and credentials, and an allow-all
// email domain. It also sets the request headers the legacy configuration
// injects by default: an Authorization header with basic-auth credentials for
// the session user, plus X-Forwarded-User and X-Forwarded-Email.
func baseTestOptions() *options.Options {
	opts := options.NewOptions()
	opts.Cookie.Secret = rawCookieSecret
	opts.EmailDomains = []string{"*"}

	opts.Providers[0].ID = "providerID"
	opts.Providers[0].ClientID = clientID
	opts.Providers[0].ClientSecret = clientSecret

	// fromClaim builds a header whose single value is read from claim.
	fromClaim := func(name, claim string) options.Header {
		return options.Header{
			Name: name,
			Values: []options.HeaderValue{
				{ClaimSource: &options.ClaimSource{Claim: claim}},
			},
		}
	}
	encodedPassword := base64.StdEncoding.EncodeToString([]byte("This is a secure password"))

	// Default injected headers for legacy configuration
	opts.InjectRequestHeaders = []options.Header{
		{
			Name: "Authorization",
			Values: []options.HeaderValue{
				{
					ClaimSource: &options.ClaimSource{
						Claim:             "user",
						BasicAuthPassword: &options.SecretSource{Value: []byte(encodedPassword)},
					},
				},
			},
		},
		fromClaim("X-Forwarded-User", "user"),
		fromClaim("X-Forwarded-Email", "email"),
	}
	return opts
}
// TestTrustedIPs checks which requests bypass authentication because their
// client IP is listed in TrustedIPs. Trusted requests reach the static
// upstream (200); all others get 403. When ReverseProxy is false the client
// IP comes from req.RemoteAddr. Otherwise it comes from the header named by
// RealClientIPHeader.
func TestTrustedIPs(t *testing.T) {
	// newGetRequest returns a GET / request after applying configure to it.
	newGetRequest := func(configure func(*http.Request)) *http.Request {
		req, _ := http.NewRequest("GET", "/", nil)
		configure(req)
		return req
	}
	fromRemoteAddr := func(addr string) *http.Request {
		return newGetRequest(func(r *http.Request) { r.RemoteAddr = addr })
	}
	withHeader := func(name, value string) *http.Request {
		return newGetRequest(func(r *http.Request) { r.Header.Add(name, value) })
	}

	tests := []struct {
		name               string
		trustedIPs         []string
		reverseProxy       bool
		realClientIPHeader string
		req                *http.Request
		expectTrusted      bool
	}{
		// Check unconfigured behavior.
		{
			name:               "Default",
			trustedIPs:         nil,
			reverseProxy:       false,
			realClientIPHeader: "X-Real-IP", // Default value
			req:                newGetRequest(func(*http.Request) {}),
			expectTrusted:      false,
		},
		// Check using req.RemoteAddr (Options.ReverseProxy == false).
		{
			name:               "WithRemoteAddr",
			trustedIPs:         []string{"127.0.0.1"},
			reverseProxy:       false,
			realClientIPHeader: "X-Real-IP", // Default value
			req:                fromRemoteAddr("127.0.0.1:43670"),
			expectTrusted:      true,
		},
		// Check ignores req.RemoteAddr match when behind a reverse proxy / missing header.
		{
			name:               "IgnoresRemoteAddrInReverseProxyMode",
			trustedIPs:         []string{"127.0.0.1"},
			reverseProxy:       true,
			realClientIPHeader: "X-Real-IP", // Default value
			req:                fromRemoteAddr("127.0.0.1:44324"),
			expectTrusted:      false,
		},
		// Check successful trusting of localhost in IPv4.
		{
			name:               "TrustsLocalhostInReverseProxyMode",
			trustedIPs:         []string{"127.0.0.0/8", "::1"},
			reverseProxy:       true,
			realClientIPHeader: "X-Forwarded-For",
			req:                withHeader("X-Forwarded-For", "127.0.0.1"),
			expectTrusted:      true,
		},
		// Check successful trusting of localhost in IPv6.
		{
			name:               "TrustsIP6LocalostInReverseProxyMode",
			trustedIPs:         []string{"127.0.0.0/8", "::1"},
			reverseProxy:       true,
			realClientIPHeader: "X-Forwarded-For",
			req:                withHeader("X-Forwarded-For", "::1"),
			expectTrusted:      true,
		},
		// Check does not trust random IPv4 address.
		{
			name:               "DoesNotTrustRandomIP4Address",
			trustedIPs:         []string{"127.0.0.0/8", "::1"},
			reverseProxy:       true,
			realClientIPHeader: "X-Forwarded-For",
			req:                withHeader("X-Forwarded-For", "12.34.56.78"),
			expectTrusted:      false,
		},
		// Check does not trust random IPv6 address.
		{
			name:               "DoesNotTrustRandomIP6Address",
			trustedIPs:         []string{"127.0.0.0/8", "::1"},
			reverseProxy:       true,
			realClientIPHeader: "X-Forwarded-For",
			req:                withHeader("X-Forwarded-For", "::2"),
			expectTrusted:      false,
		},
		// Check respects correct header.
		{
			name:               "RespectsCorrectHeaderInReverseProxyMode",
			trustedIPs:         []string{"127.0.0.0/8", "::1"},
			reverseProxy:       true,
			realClientIPHeader: "X-Forwarded-For",
			req:                withHeader("X-Real-IP", "::1"),
			expectTrusted:      false,
		},
		// Check doesn't trust if garbage is provided.
		{
			name:               "DoesNotTrustGarbageInReverseProxyMode",
			trustedIPs:         []string{"127.0.0.0/8", "::1"},
			reverseProxy:       true,
			realClientIPHeader: "X-Forwarded-For",
			req:                withHeader("X-Forwarded-For", "adsfljk29242as!!"),
			expectTrusted:      false,
		},
		// Check doesn't trust if garbage is provided (no reverse-proxy).
		{
			name:               "DoesNotTrustGarbage",
			trustedIPs:         []string{"127.0.0.0/8", "::1"},
			reverseProxy:       false,
			realClientIPHeader: "X-Real-IP",
			req:                fromRemoteAddr("adsfljk29242as!!"),
			expectTrusted:      false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			opts := baseTestOptions()
			opts.UpstreamServers = options.UpstreamConfig{
				Upstreams: []options.Upstream{
					{
						ID:     "static",
						Path:   "/",
						Static: true,
					},
				},
			}
			opts.TrustedIPs = tt.trustedIPs
			opts.ReverseProxy = tt.reverseProxy
			opts.RealClientIPHeader = tt.realClientIPHeader
			err := validation.Validate(opts)
			assert.NoError(t, err)

			proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
			assert.NoError(t, err)

			rw := httptest.NewRecorder()
			proxy.ServeHTTP(rw, tt.req)

			wantCode := 403
			if tt.expectTrusted {
				wantCode = 200
			}
			assert.Equal(t, wantCode, rw.Code)
		})
	}
}
// Test_buildRoutesAllowlist checks that buildRoutesAllowlist turns
// SkipAuthRegex entries (any method) and SkipAuthRoutes entries (optional
// "METHOD=" prefix) into allowed routes in that order, and that an invalid
// regex in either list is reported as an error.
func Test_buildRoutesAllowlist(t *testing.T) {
	type expectedAllowedRoute struct {
		method      string
		regexString string
	}

	testCases := []struct {
		name           string
		skipAuthRegex  []string
		skipAuthRoutes []string
		expectedRoutes []expectedAllowedRoute
		shouldError    bool
	}{
		{
			name:           "No skip auth configured",
			skipAuthRegex:  []string{},
			skipAuthRoutes: []string{},
			expectedRoutes: []expectedAllowedRoute{},
			shouldError:    false,
		},
		{
			name: "Only skipAuthRegex configured",
			skipAuthRegex: []string{
				"^/foo/bar",
				"^/baz/[0-9]+/thing",
			},
			skipAuthRoutes: []string{},
			expectedRoutes: []expectedAllowedRoute{
				{method: "", regexString: "^/foo/bar"},
				{method: "", regexString: "^/baz/[0-9]+/thing"},
			},
			shouldError: false,
		},
		{
			name:          "Only skipAuthRoutes configured",
			skipAuthRegex: []string{},
			skipAuthRoutes: []string{
				"GET=^/foo/bar",
				"POST=^/baz/[0-9]+/thing",
				"^/all/methods$",
				"WEIRD=^/methods/are/allowed",
				"PATCH=/second/equals?are=handled&just=fine",
			},
			expectedRoutes: []expectedAllowedRoute{
				{method: "GET", regexString: "^/foo/bar"},
				{method: "POST", regexString: "^/baz/[0-9]+/thing"},
				{method: "", regexString: "^/all/methods$"},
				{method: "WEIRD", regexString: "^/methods/are/allowed"},
				{method: "PATCH", regexString: "/second/equals?are=handled&just=fine"},
			},
			shouldError: false,
		},
		{
			name: "Both skipAuthRegexes and skipAuthRoutes configured",
			skipAuthRegex: []string{
				"^/foo/bar/regex",
				"^/baz/[0-9]+/thing/regex",
			},
			skipAuthRoutes: []string{
				"GET=^/foo/bar",
				"POST=^/baz/[0-9]+/thing",
				"^/all/methods$",
			},
			expectedRoutes: []expectedAllowedRoute{
				{method: "", regexString: "^/foo/bar/regex"},
				{method: "", regexString: "^/baz/[0-9]+/thing/regex"},
				{method: "GET", regexString: "^/foo/bar"},
				{method: "POST", regexString: "^/baz/[0-9]+/thing"},
				{method: "", regexString: "^/all/methods$"},
			},
			shouldError: false,
		},
		{
			name: "Invalid skipAuthRegex entry",
			skipAuthRegex: []string{
				"^/foo/bar",
				"^/baz/[0-9]+/thing",
				"(bad[regex",
			},
			skipAuthRoutes: []string{},
			expectedRoutes: []expectedAllowedRoute{},
			shouldError:    true,
		},
		{
			name:          "Invalid skipAuthRoutes entry",
			skipAuthRegex: []string{},
			skipAuthRoutes: []string{
				"GET=^/foo/bar",
				"POST=^/baz/[0-9]+/thing",
				"^/all/methods$",
				"PUT=(bad[regex",
			},
			expectedRoutes: []expectedAllowedRoute{},
			shouldError:    true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			opts := &options.Options{
				SkipAuthRegex:  tc.skipAuthRegex,
				SkipAuthRoutes: tc.skipAuthRoutes,
			}
			routes, err := buildRoutesAllowlist(opts)
			if tc.shouldError {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)

			// Check the length first. Iterating over routes alone lets a
			// result that is missing expected entries pass, and an extra entry
			// would index past tc.expectedRoutes.
			if !assert.Len(t, routes, len(tc.expectedRoutes)) {
				return
			}
			for i, route := range routes {
				assert.Equal(t, tc.expectedRoutes[i].method, route.method)
				assert.Equal(t, tc.expectedRoutes[i].regexString, route.pathRegex.String())
			}
		})
	}
}
// TestAllowedRequest checks that requests matching SkipAuthRegex (any method)
// or SkipAuthRoutes (method and path) reach the upstream unauthenticated,
// while all other requests are rejected with 403.
func TestAllowedRequest(t *testing.T) {
	upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		// This handler runs on the server's goroutine, not the test's.
		// t.Fatal (FailNow) must not be called there; t.Error is safe.
		if _, err := w.Write([]byte("Allowed Request")); err != nil {
			t.Error(err)
		}
	}))
	t.Cleanup(upstreamServer.Close)

	opts := baseTestOptions()
	opts.UpstreamServers = options.UpstreamConfig{
		Upstreams: []options.Upstream{
			{
				ID:   upstreamServer.URL,
				Path: "/",
				URI:  upstreamServer.URL,
			},
		},
	}
	opts.SkipAuthRegex = []string{
		"^/skip/auth/regex$",
	}
	opts.SkipAuthRoutes = []string{
		"GET=^/skip/auth/routes/get",
	}
	err := validation.Validate(opts)
	assert.NoError(t, err)
	proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
	if err != nil {
		t.Fatal(err)
	}

	testCases := []struct {
		name    string
		method  string
		url     string
		allowed bool
	}{
		{name: "Regex GET allowed", method: "GET", url: "/skip/auth/regex", allowed: true},
		{name: "Regex POST allowed ", method: "POST", url: "/skip/auth/regex", allowed: true},
		{name: "Regex denied", method: "GET", url: "/wrong/denied", allowed: false},
		{name: "Route allowed", method: "GET", url: "/skip/auth/routes/get", allowed: true},
		{name: "Route denied with wrong method", method: "PATCH", url: "/skip/auth/routes/get", allowed: false},
		{name: "Route denied with wrong path", method: "GET", url: "/skip/auth/routes/wrong/path", allowed: false},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			req, err := http.NewRequest(tc.method, tc.url, nil)
			assert.NoError(t, err)
			assert.Equal(t, tc.allowed, proxy.isAllowedRoute(req))

			rw := httptest.NewRecorder()
			proxy.ServeHTTP(rw, req)

			if tc.allowed {
				assert.Equal(t, 200, rw.Code)
				assert.Equal(t, "Allowed Request", rw.Body.String())
			} else {
				assert.Equal(t, 403, rw.Code)
			}
		})
	}
}
func TestProxyAllowedGroups(t *testing.T) {
tests := []struct {
name string
allowedGroups []string
groups []string
expectUnauthorized bool
}{
{"NoAllowedGroups", []string{}, []string{}, false},
{"NoAllowedGroupsUserHasGroups", []string{}, []string{"a", "b"}, false},
{"UserInAllowedGroup", []string{"a"}, []string{"a", "b"}, false},
{"UserNotInAllowedGroup", []string{"a"}, []string{"c"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
emailAddress := "test"
created := time.Now()
session := &sessions.SessionState{
Groups: tt.groups,
Email: emailAddress,
AccessToken: "oauth_token",
CreatedAt: &created,
}
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}))
t.Cleanup(upstreamServer.Close)
test, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) {
opts.Providers[0].AllowedGroups = tt.allowedGroups
opts.UpstreamServers = options.UpstreamConfig{
Upstreams: []options.Upstream{
{
ID: upstreamServer.URL,
Path: "/",
URI: upstreamServer.URL,
},
},
}
})
if err != nil {
t.Fatal(err)
}
test.req, _ = http.NewRequest("GET", "/", nil)
test.req.Header.Add("accept", applicationJSON)
err = test.SaveSession(session)
assert.NoError(t, err)
test.proxy.ServeHTTP(test.rw, test.req)
if tt.expectUnauthorized {
assert.Equal(t, http.StatusForbidden, test.rw.Code)
} else {
assert.Equal(t, http.StatusOK, test.rw.Code)
}
})
}
}
// TestAuthOnlyAllowedGroups checks the /auth endpoint's status code for
// combinations of provider-level AllowedGroups, the session's groups, and an
// allowed_groups query-string filter. The expected codes (202, 401 or 403)
// are listed per case.
func TestAuthOnlyAllowedGroups(t *testing.T) {
	testCases := []struct {
		name               string
		allowedGroups      []string
		groups             []string
		querystring        string
		expectedStatusCode int
	}{
		{name: "NoAllowedGroups", allowedGroups: []string{}, groups: []string{}, querystring: "", expectedStatusCode: http.StatusAccepted},
		{name: "NoAllowedGroupsUserHasGroups", allowedGroups: []string{}, groups: []string{"a", "b"}, querystring: "", expectedStatusCode: http.StatusAccepted},
		{name: "UserInAllowedGroup", allowedGroups: []string{"a"}, groups: []string{"a", "b"}, querystring: "", expectedStatusCode: http.StatusAccepted},
		{name: "UserNotInAllowedGroup", allowedGroups: []string{"a"}, groups: []string{"c"}, querystring: "", expectedStatusCode: http.StatusUnauthorized},
		{name: "UserInQuerystringGroup", allowedGroups: []string{"a", "b"}, groups: []string{"a", "c"}, querystring: "?allowed_groups=a", expectedStatusCode: http.StatusAccepted},
		{name: "UserInMultiParamQuerystringGroup", allowedGroups: []string{"a", "b"}, groups: []string{"b"}, querystring: "?allowed_groups=a&allowed_groups=b,d", expectedStatusCode: http.StatusAccepted},
		{name: "UserInOnlyQuerystringGroup", allowedGroups: []string{}, groups: []string{"a", "c"}, querystring: "?allowed_groups=a,b", expectedStatusCode: http.StatusAccepted},
		{name: "UserInDelimitedQuerystringGroup", allowedGroups: []string{"a", "b", "c"}, groups: []string{"c"}, querystring: "?allowed_groups=a,c", expectedStatusCode: http.StatusAccepted},
		{name: "UserNotInQuerystringGroup", allowedGroups: []string{}, groups: []string{"c"}, querystring: "?allowed_groups=a,b", expectedStatusCode: http.StatusForbidden},
		{name: "UserInConfigGroupNotInQuerystringGroup", allowedGroups: []string{"a", "b", "c"}, groups: []string{"c"}, querystring: "?allowed_groups=a,b", expectedStatusCode: http.StatusForbidden},
		{name: "UserInQuerystringGroupNotInConfigGroup", allowedGroups: []string{"a", "b"}, groups: []string{"c"}, querystring: "?allowed_groups=b,c", expectedStatusCode: http.StatusUnauthorized},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			created := time.Now()
			session := &sessions.SessionState{
				Groups:      tc.groups,
				Email:       "test",
				AccessToken: "oauth_token",
				CreatedAt:   &created,
			}

			test, err := NewAuthOnlyEndpointTest(tc.querystring, func(opts *options.Options) {
				opts.Providers[0].AllowedGroups = tc.allowedGroups
			})
			if err != nil {
				t.Fatal(err)
			}

			assert.NoError(t, test.SaveSession(session))
			test.proxy.ServeHTTP(test.rw, test.req)
			assert.Equal(t, tc.expectedStatusCode, test.rw.Code)
		})
	}
}
// TestAuthOnlyAllowedGroupsWithSkipMethods checks the /auth endpoint when the
// request would normally skip authentication: either an OPTIONS preflight
// (SkipAuthPreflight) or a request from a trusted IP. With an
// allowed_groups=a,b filter, a session outside those groups is still
// rejected with 403. Requests without a session are accepted with 202.
func TestAuthOnlyAllowedGroupsWithSkipMethods(t *testing.T) {
	testCases := []struct {
		name               string
		groups             []string
		method             string
		ip                 string
		withSession        bool
		expectedStatusCode int
	}{
		{name: "UserWithGroupSkipAuthPreflight", groups: []string{"a", "c"}, method: "OPTIONS", ip: "1.2.3.5:43670", withSession: true, expectedStatusCode: http.StatusAccepted},
		{name: "UserWithGroupTrustedIp", groups: []string{"a", "c"}, method: "GET", ip: "1.2.3.4:43670", withSession: true, expectedStatusCode: http.StatusAccepted},
		{name: "UserWithoutGroupSkipAuthPreflight", groups: []string{"c"}, method: "OPTIONS", ip: "1.2.3.5:43670", withSession: true, expectedStatusCode: http.StatusForbidden},
		{name: "UserWithoutGroupTrustedIp", groups: []string{"c"}, method: "GET", ip: "1.2.3.4:43670", withSession: true, expectedStatusCode: http.StatusForbidden},
		{name: "UserWithoutSessionSkipAuthPreflight", method: "OPTIONS", ip: "1.2.3.5:43670", withSession: false, expectedStatusCode: http.StatusAccepted},
		{name: "UserWithoutSessionTrustedIp", method: "GET", ip: "1.2.3.4:43670", withSession: false, expectedStatusCode: http.StatusAccepted},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			test, err := NewAuthOnlyEndpointTest("?allowed_groups=a,b", func(opts *options.Options) {
				opts.SkipAuthPreflight = true
				opts.TrustedIPs = []string{"1.2.3.4"}
			})
			if err != nil {
				t.Fatal(err)
			}
			test.req.Method = tc.method
			test.req.RemoteAddr = tc.ip

			if tc.withSession {
				created := time.Now()
				err := test.SaveSession(&sessions.SessionState{
					Groups:      tc.groups,
					Email:       "test",
					AccessToken: "oauth_token",
					CreatedAt:   &created,
				})
				assert.NoError(t, err)
			}

			test.proxy.ServeHTTP(test.rw, test.req)
			assert.Equal(t, tc.expectedStatusCode, test.rw.Code)
		})
	}
}