mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2024-11-24 08:52:25 +02:00
ce750e9b30
* Add the allowed_email_domains and the allowed_groups on the auth_request endpoint + support standard wildcard char for validation with sub-domain and email-domain. Signed-off-by: Valentin Pichard <github@w3st.fr> * Fix provider data initialisation * PKCE Support Adds Code Challenge PKCE support (RFC-7636) and partial Authorization Server Metadata (RFC-8414) for detecting PKCE support. - Introduces new option `--force-code-challenge-method` to force a specific code challenge method (either `S256` or `plain`) for instances when the server has not implemented RFC-8414 in order to detect PKCE support on the discovery document. - In all other cases, if the PKCE support can be determined during discovery then the `code_challenge_methods_supported` is used and S256 is always preferred. - The force command line argument is helpful with some providers like Azure who supports PKCE but does not list it in their discovery document yet. - Initial thought was given to just always attempt PKCE since according to spec additional URL parameters should be dropped by servers which implemented OAuth 2, however other projects found cases in the wild where this causes 500 errors by buggy implementations. See: https://github.com/spring-projects/spring-security/pull/7804#issuecomment-578323810 - Due to the fact that the `code_verifier` must be saved between the redirect and callback, sessions are now created when the redirect takes place with `Authenticated: false`. The session will be recreated and marked as `Authenticated` on callback. - Individual provider implementations can choose to include or ignore code_challenge and code_verifier function parameters passed to them Note: Technically speaking `plain` is not required to be implemented since oauth2-proxy will always be able to handle S256 and servers MUST implement S256 support. > If the client is capable of using "S256", it MUST use "S256", as "S256" > is Mandatory To Implement (MTI) on the server. Clients are permitted > to use "plain" only if they cannot support "S256" for some technical > reason and know via out-of-band configuration that the server supports > "plain". Ref: RFC-7636 Sec 4.2 oauth2-proxy will always use S256 unless the user explicitly forces `plain`. Fixes #1361 * Address PR comments by moving pkce generation * Make PKCE opt-in, move to using the Nonce generater for code verifier * Make PKCE opt-in, move to using the Nonce generater for code verifier * Encrypt CodeVerifier in CSRF Token instead of Session - Update Dex for PKCE support - Expose HTTPBin for further use cases * Correct the tests * Move code challenges into extra params * Correct typo in code challenge method Co-authored-by: Joel Speed <Joel.speed@hotmail.co.uk> * Correct the extra space in docs Co-authored-by: Joel Speed <Joel.speed@hotmail.co.uk> * Address changelog and new line nits * Add generated docs Co-authored-by: Valentin Pichard <github@w3st.fr> Co-authored-by: Joel Speed <joel.speed@hotmail.co.uk>
2792 lines
72 KiB
Go
2792 lines
72 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"regexp"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/coreos/go-oidc/v3/oidc"
|
|
"github.com/mbland/hmacauth"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
|
internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/providers/oidc"
|
|
sessionscookie "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/cookie"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/upstream"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/validation"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/providers"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
const (
|
|
// The rawCookieSecret is 32 bytes and the base64CookieSecret is the base64
|
|
// encoded version of this.
|
|
rawCookieSecret = "secretthirtytwobytes+abcdefghijk"
|
|
base64CookieSecret = "c2VjcmV0dGhpcnR5dHdvYnl0ZXMrYWJjZGVmZ2hpams"
|
|
clientID = "3984n253984d7348dm8234yf982t"
|
|
clientSecret = "gv3498mfc9t23y23974dm2394dm9"
|
|
)
|
|
|
|
func init() {
|
|
logger.SetFlags(logger.Lshortfile)
|
|
}
|
|
|
|
func TestRobotsTxt(t *testing.T) {
|
|
opts := baseTestOptions()
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
|
|
proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
rw := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", "/robots.txt", nil)
|
|
proxy.ServeHTTP(rw, req)
|
|
assert.Equal(t, 200, rw.Code)
|
|
assert.Equal(t, "User-agent: *\nDisallow: /\n", rw.Body.String())
|
|
}
|
|
|
|
type TestProvider struct {
|
|
*providers.ProviderData
|
|
EmailAddress string
|
|
ValidToken bool
|
|
GroupValidator func(string) bool
|
|
}
|
|
|
|
var _ providers.Provider = (*TestProvider)(nil)
|
|
|
|
func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
|
|
return &TestProvider{
|
|
ProviderData: &providers.ProviderData{
|
|
ProviderName: "Test Provider",
|
|
LoginURL: &url.URL{
|
|
Scheme: "http",
|
|
Host: providerURL.Host,
|
|
Path: "/oauth/authorize",
|
|
},
|
|
RedeemURL: &url.URL{
|
|
Scheme: "http",
|
|
Host: providerURL.Host,
|
|
Path: "/oauth/token",
|
|
},
|
|
ProfileURL: &url.URL{
|
|
Scheme: "http",
|
|
Host: providerURL.Host,
|
|
Path: "/api/v1/profile",
|
|
},
|
|
Scope: "profile.email",
|
|
},
|
|
EmailAddress: emailAddress,
|
|
GroupValidator: func(s string) bool {
|
|
return true
|
|
},
|
|
}
|
|
}
|
|
|
|
func (tp *TestProvider) GetEmailAddress(_ context.Context, _ *sessions.SessionState) (string, error) {
|
|
return tp.EmailAddress, nil
|
|
}
|
|
|
|
func (tp *TestProvider) ValidateSession(_ context.Context, _ *sessions.SessionState) bool {
|
|
return tp.ValidToken
|
|
}
|
|
|
|
func Test_redeemCode(t *testing.T) {
|
|
opts := baseTestOptions()
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
|
|
proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
_, err = proxy.redeemCode(req, "")
|
|
assert.Equal(t, providers.ErrMissingCode, err)
|
|
}
|
|
|
|
func Test_enrichSession(t *testing.T) {
|
|
const (
|
|
sessionUser = "Mr Session"
|
|
sessionEmail = "session@example.com"
|
|
providerEmail = "provider@example.com"
|
|
)
|
|
|
|
testCases := map[string]struct {
|
|
session *sessions.SessionState
|
|
expectedUser string
|
|
expectedEmail string
|
|
}{
|
|
"Session already has enrichable fields": {
|
|
session: &sessions.SessionState{
|
|
User: sessionUser,
|
|
Email: sessionEmail,
|
|
},
|
|
expectedUser: sessionUser,
|
|
expectedEmail: sessionEmail,
|
|
},
|
|
"Session is missing Email and GetEmailAddress is implemented": {
|
|
session: &sessions.SessionState{
|
|
User: sessionUser,
|
|
},
|
|
expectedUser: sessionUser,
|
|
expectedEmail: providerEmail,
|
|
},
|
|
"Session is missing User and GetUserName is not implemented": {
|
|
session: &sessions.SessionState{
|
|
Email: sessionEmail,
|
|
},
|
|
expectedUser: "",
|
|
expectedEmail: sessionEmail,
|
|
},
|
|
}
|
|
|
|
for name, tc := range testCases {
|
|
t.Run(name, func(t *testing.T) {
|
|
opts := baseTestOptions()
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
|
|
proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
proxy.provider = NewTestProvider(&url.URL{Host: "www.example.com"}, providerEmail)
|
|
|
|
err = proxy.enrichSessionState(context.Background(), tc.session)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tc.expectedUser, tc.session.User)
|
|
assert.Equal(t, tc.expectedEmail, tc.session.Email)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestBasicAuthPassword(t *testing.T) {
|
|
providerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
logger.Printf("%#v", r)
|
|
var payload string
|
|
switch r.URL.Path {
|
|
case "/oauth/token":
|
|
payload = `{"access_token": "my_auth_token"}`
|
|
default:
|
|
payload = r.Header.Get("Authorization")
|
|
if payload == "" {
|
|
payload = "No Authorization header found."
|
|
}
|
|
}
|
|
w.WriteHeader(200)
|
|
_, err := w.Write([]byte(payload))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}))
|
|
|
|
basicAuthPassword := "This is a secure password"
|
|
opts := baseTestOptions()
|
|
opts.UpstreamServers = options.UpstreamConfig{
|
|
Upstreams: []options.Upstream{
|
|
{
|
|
ID: providerServer.URL,
|
|
Path: "/",
|
|
URI: providerServer.URL,
|
|
},
|
|
},
|
|
}
|
|
|
|
opts.Cookie.Secure = false
|
|
opts.InjectRequestHeaders = []options.Header{
|
|
{
|
|
Name: "Authorization",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "email",
|
|
BasicAuthPassword: &options.SecretSource{
|
|
Value: []byte(basicAuthPassword),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
|
|
providerURL, _ := url.Parse(providerServer.URL)
|
|
const emailAddress = "john.doe@example.com"
|
|
|
|
proxy, err := NewOAuthProxy(opts, func(email string) bool {
|
|
return email == emailAddress
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
proxy.provider = NewTestProvider(providerURL, emailAddress)
|
|
|
|
// Save the required session
|
|
rw := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
err = proxy.sessionStore.Save(rw, req, &sessions.SessionState{
|
|
Email: emailAddress,
|
|
})
|
|
assert.NoError(t, err)
|
|
|
|
// Extract the cookie value to inject into the test request
|
|
cookie := rw.Header().Values("Set-Cookie")[0]
|
|
|
|
req, _ = http.NewRequest("GET", "/", nil)
|
|
req.Header.Set("Cookie", cookie)
|
|
rw = httptest.NewRecorder()
|
|
proxy.ServeHTTP(rw, req)
|
|
|
|
// The username in the basic auth credentials is expected to be equal to the email address from the
|
|
// auth response, so we use the same variable here.
|
|
expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(emailAddress+":"+basicAuthPassword))
|
|
assert.Equal(t, expectedHeader, rw.Body.String())
|
|
providerServer.Close()
|
|
}
|
|
|
|
func TestPassGroupsHeadersWithGroups(t *testing.T) {
|
|
opts := baseTestOptions()
|
|
opts.InjectRequestHeaders = []options.Header{
|
|
{
|
|
Name: "X-Forwarded-Groups",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "groups",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
|
|
const emailAddress = "john.doe@example.com"
|
|
const userName = "9fcab5c9b889a557"
|
|
|
|
groups := []string{"a", "b"}
|
|
created := time.Now()
|
|
session := &sessions.SessionState{
|
|
User: userName,
|
|
Groups: groups,
|
|
Email: emailAddress,
|
|
AccessToken: "oauth_token",
|
|
CreatedAt: &created,
|
|
}
|
|
|
|
proxy, err := NewOAuthProxy(opts, func(email string) bool {
|
|
return email == emailAddress
|
|
})
|
|
assert.NoError(t, err)
|
|
|
|
// Save the required session
|
|
rw := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
err = proxy.sessionStore.Save(rw, req, session)
|
|
assert.NoError(t, err)
|
|
|
|
// Extract the cookie value to inject into the test request
|
|
cookie := rw.Header().Values("Set-Cookie")[0]
|
|
|
|
req, _ = http.NewRequest("GET", "/", nil)
|
|
req.Header.Set("Cookie", cookie)
|
|
rw = httptest.NewRecorder()
|
|
proxy.ServeHTTP(rw, req)
|
|
|
|
assert.Equal(t, []string{"a,b"}, req.Header["X-Forwarded-Groups"])
|
|
}
|
|
|
|
type PassAccessTokenTest struct {
|
|
providerServer *httptest.Server
|
|
proxy *OAuthProxy
|
|
opts *options.Options
|
|
}
|
|
|
|
type PassAccessTokenTestOptions struct {
|
|
PassAccessToken bool
|
|
ValidToken bool
|
|
ProxyUpstream options.Upstream
|
|
}
|
|
|
|
func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTest, error) {
|
|
patt := &PassAccessTokenTest{}
|
|
|
|
patt.providerServer = httptest.NewServer(
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var payload string
|
|
switch r.URL.Path {
|
|
case "/oauth/token":
|
|
payload = `{"access_token": "my_auth_token"}`
|
|
default:
|
|
payload = r.Header.Get("X-Forwarded-Access-Token")
|
|
if payload == "" {
|
|
payload = "No access token found."
|
|
}
|
|
}
|
|
w.WriteHeader(200)
|
|
_, err := w.Write([]byte(payload))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}))
|
|
|
|
patt.opts = baseTestOptions()
|
|
patt.opts.UpstreamServers = options.UpstreamConfig{
|
|
Upstreams: []options.Upstream{
|
|
{
|
|
ID: patt.providerServer.URL,
|
|
Path: "/",
|
|
URI: patt.providerServer.URL,
|
|
},
|
|
},
|
|
}
|
|
if opts.ProxyUpstream.ID != "" {
|
|
patt.opts.UpstreamServers.Upstreams = append(patt.opts.UpstreamServers.Upstreams, opts.ProxyUpstream)
|
|
}
|
|
|
|
patt.opts.Cookie.Secure = false
|
|
if opts.PassAccessToken {
|
|
patt.opts.InjectRequestHeaders = []options.Header{
|
|
{
|
|
Name: "X-Forwarded-Access-Token",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "access_token",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
err := validation.Validate(patt.opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
providerURL, _ := url.Parse(patt.providerServer.URL)
|
|
const emailAddress = "michael.bland@gsa.gov"
|
|
|
|
testProvider := NewTestProvider(providerURL, emailAddress)
|
|
testProvider.ValidToken = opts.ValidToken
|
|
patt.proxy, err = NewOAuthProxy(patt.opts, func(email string) bool {
|
|
return email == emailAddress
|
|
})
|
|
patt.proxy.provider = testProvider
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return patt, nil
|
|
}
|
|
|
|
func (patTest *PassAccessTokenTest) Close() {
|
|
patTest.providerServer.Close()
|
|
}
|
|
|
|
func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int, cookie string) {
|
|
rw := httptest.NewRecorder()
|
|
|
|
csrf, err := cookies.NewCSRF(patTest.proxy.CookieOptions, "")
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
req, err := http.NewRequest(
|
|
http.MethodGet,
|
|
fmt.Sprintf(
|
|
"/oauth2/callback?code=callback_code&state=%s",
|
|
encodeState(csrf.HashOAuthState(), "%2F"),
|
|
),
|
|
strings.NewReader(""),
|
|
)
|
|
if err != nil {
|
|
return 0, ""
|
|
}
|
|
|
|
// rw is a dummy here, we just want the csrfCookie to add to our req
|
|
csrfCookie, err := csrf.SetCookie(httptest.NewRecorder(), req)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
req.AddCookie(csrfCookie)
|
|
|
|
patTest.proxy.ServeHTTP(rw, req)
|
|
|
|
if len(rw.Header().Values("Set-Cookie")) >= 2 {
|
|
cookie = rw.Header().Values("Set-Cookie")[1]
|
|
}
|
|
|
|
return rw.Code, cookie
|
|
}
|
|
|
|
// getEndpointWithCookie makes a requests againt the oauthproxy with passed requestPath
|
|
// and cookie and returns body and status code.
|
|
func (patTest *PassAccessTokenTest) getEndpointWithCookie(cookie string, endpoint string) (httpCode int, accessToken string) {
|
|
cookieName := patTest.proxy.CookieOptions.Name
|
|
var value string
|
|
keyPrefix := cookieName + "="
|
|
|
|
for _, field := range strings.Split(cookie, "; ") {
|
|
value = strings.TrimPrefix(field, keyPrefix)
|
|
if value != field {
|
|
break
|
|
} else {
|
|
value = ""
|
|
}
|
|
}
|
|
if value == "" {
|
|
return 0, ""
|
|
}
|
|
|
|
req, err := http.NewRequest("GET", endpoint, strings.NewReader(""))
|
|
if err != nil {
|
|
return 0, ""
|
|
}
|
|
req.AddCookie(&http.Cookie{
|
|
Name: cookieName,
|
|
Value: value,
|
|
Path: "/",
|
|
Expires: time.Now().Add(time.Duration(24)),
|
|
HttpOnly: true,
|
|
})
|
|
|
|
rw := httptest.NewRecorder()
|
|
patTest.proxy.ServeHTTP(rw, req)
|
|
return rw.Code, rw.Body.String()
|
|
}
|
|
|
|
func TestForwardAccessTokenUpstream(t *testing.T) {
|
|
patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
|
PassAccessToken: true,
|
|
ValidToken: true,
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Cleanup(patTest.Close)
|
|
|
|
// A successful validation will redirect and set the auth cookie.
|
|
code, cookie := patTest.getCallbackEndpoint()
|
|
if code != 302 {
|
|
t.Fatalf("expected 302; got %d", code)
|
|
}
|
|
assert.NotNil(t, cookie)
|
|
|
|
// Now we make a regular request; the access_token from the cookie is
|
|
// forwarded as the "X-Forwarded-Access-Token" header. The token is
|
|
// read by the test provider server and written in the response body.
|
|
code, payload := patTest.getEndpointWithCookie(cookie, "/")
|
|
if code != 200 {
|
|
t.Fatalf("expected 200; got %d", code)
|
|
}
|
|
assert.Equal(t, "my_auth_token", payload)
|
|
}
|
|
|
|
func TestStaticProxyUpstream(t *testing.T) {
|
|
patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
|
PassAccessToken: true,
|
|
ValidToken: true,
|
|
ProxyUpstream: options.Upstream{
|
|
ID: "static-proxy",
|
|
Path: "/static-proxy",
|
|
Static: true,
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Cleanup(patTest.Close)
|
|
|
|
// A successful validation will redirect and set the auth cookie.
|
|
code, cookie := patTest.getCallbackEndpoint()
|
|
if code != 302 {
|
|
t.Fatalf("expected 302; got %d", code)
|
|
}
|
|
assert.NotEqual(t, nil, cookie)
|
|
|
|
// Now we make a regular request against the upstream proxy; And validate
|
|
// the returned status code through the static proxy.
|
|
code, payload := patTest.getEndpointWithCookie(cookie, "/static-proxy")
|
|
if code != 200 {
|
|
t.Fatalf("expected 200; got %d", code)
|
|
}
|
|
assert.Equal(t, "Authenticated", payload)
|
|
}
|
|
|
|
func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
|
|
patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
|
PassAccessToken: false,
|
|
ValidToken: true,
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Cleanup(patTest.Close)
|
|
|
|
// A successful validation will redirect and set the auth cookie.
|
|
code, cookie := patTest.getCallbackEndpoint()
|
|
if code != 302 {
|
|
t.Fatalf("expected 302; got %d", code)
|
|
}
|
|
assert.NotEqual(t, nil, cookie)
|
|
|
|
// Now we make a regular request, but the access token header should
|
|
// not be present.
|
|
code, payload := patTest.getEndpointWithCookie(cookie, "/")
|
|
if code != 200 {
|
|
t.Fatalf("expected 200; got %d", code)
|
|
}
|
|
assert.Equal(t, "No access token found.", payload)
|
|
}
|
|
|
|
func TestSessionValidationFailure(t *testing.T) {
|
|
patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
|
ValidToken: false,
|
|
})
|
|
require.NoError(t, err)
|
|
t.Cleanup(patTest.Close)
|
|
|
|
// An unsuccessful validation will return 403 and not set the auth cookie.
|
|
code, cookie := patTest.getCallbackEndpoint()
|
|
assert.Equal(t, http.StatusForbidden, code)
|
|
assert.Equal(t, "", cookie)
|
|
}
|
|
|
|
type SignInPageTest struct {
|
|
opts *options.Options
|
|
proxy *OAuthProxy
|
|
signInRegexp *regexp.Regexp
|
|
signInProviderRegexp *regexp.Regexp
|
|
}
|
|
|
|
const signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">`
|
|
const signInSkipProvider = `>Found<`
|
|
|
|
func NewSignInPageTest(skipProvider bool) (*SignInPageTest, error) {
|
|
var sipTest SignInPageTest
|
|
|
|
sipTest.opts = baseTestOptions()
|
|
sipTest.opts.SkipProviderButton = skipProvider
|
|
err := validation.Validate(sipTest.opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sipTest.proxy, err = NewOAuthProxy(sipTest.opts, func(email string) bool {
|
|
return true
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
sipTest.signInRegexp = regexp.MustCompile(signInRedirectPattern)
|
|
sipTest.signInProviderRegexp = regexp.MustCompile(signInSkipProvider)
|
|
|
|
return &sipTest, nil
|
|
}
|
|
|
|
func (sipTest *SignInPageTest) GetEndpoint(endpoint string) (int, string) {
|
|
rw := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", endpoint, strings.NewReader(""))
|
|
sipTest.proxy.ServeHTTP(rw, req)
|
|
return rw.Code, rw.Body.String()
|
|
}
|
|
|
|
type AlwaysSuccessfulValidator struct {
|
|
}
|
|
|
|
func (AlwaysSuccessfulValidator) Validate(user, password string) bool {
|
|
return true
|
|
}
|
|
|
|
func TestManualSignInStoresUserGroupsInTheSession(t *testing.T) {
|
|
userGroups := []string{"somegroup", "someothergroup"}
|
|
|
|
opts := baseTestOptions()
|
|
opts.HtpasswdUserGroups = userGroups
|
|
err := validation.Validate(opts)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
proxy, err := NewOAuthProxy(opts, func(email string) bool {
|
|
return true
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
proxy.basicAuthValidator = AlwaysSuccessfulValidator{}
|
|
|
|
rw := httptest.NewRecorder()
|
|
formData := url.Values{}
|
|
formData.Set("username", "someuser")
|
|
formData.Set("password", "somepass")
|
|
signInReq, _ := http.NewRequest(http.MethodPost, "/oauth2/sign_in", strings.NewReader(formData.Encode()))
|
|
signInReq.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
|
proxy.ServeHTTP(rw, signInReq)
|
|
|
|
assert.Equal(t, http.StatusFound, rw.Code)
|
|
|
|
req, _ := http.NewRequest(http.MethodGet, "/something", strings.NewReader(formData.Encode()))
|
|
for _, c := range rw.Result().Cookies() {
|
|
req.AddCookie(c)
|
|
}
|
|
|
|
s, err := proxy.sessionStore.Load(req)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
assert.Equal(t, userGroups, s.Groups)
|
|
}
|
|
|
|
func TestSignInPageIncludesTargetRedirect(t *testing.T) {
|
|
sipTest, err := NewSignInPageTest(false)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
const endpoint = "/some/random/endpoint"
|
|
|
|
code, body := sipTest.GetEndpoint(endpoint)
|
|
assert.Equal(t, 403, code)
|
|
|
|
match := sipTest.signInRegexp.FindStringSubmatch(body)
|
|
if match == nil {
|
|
t.Fatal("Did not find pattern in body: " +
|
|
signInRedirectPattern + "\nBody:\n" + body)
|
|
}
|
|
if match[1] != endpoint {
|
|
t.Fatal(`expected redirect to "` + endpoint +
|
|
`", but was "` + match[1] + `"`)
|
|
}
|
|
}
|
|
|
|
func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
|
|
sipTest, err := NewSignInPageTest(false)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
code, body := sipTest.GetEndpoint("/oauth2/sign_in")
|
|
assert.Equal(t, 200, code)
|
|
|
|
match := sipTest.signInRegexp.FindStringSubmatch(body)
|
|
if match == nil {
|
|
t.Fatal("Did not find pattern in body: " +
|
|
signInRedirectPattern + "\nBody:\n" + body)
|
|
}
|
|
if match[1] != "/" {
|
|
t.Fatal(`expected redirect to "/", but was "` + match[1] + `"`)
|
|
}
|
|
}
|
|
|
|
func TestSignInPageSkipProvider(t *testing.T) {
|
|
sipTest, err := NewSignInPageTest(true)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
endpoint := "/some/random/endpoint"
|
|
|
|
code, body := sipTest.GetEndpoint(endpoint)
|
|
assert.Equal(t, 302, code)
|
|
|
|
match := sipTest.signInProviderRegexp.FindStringSubmatch(body)
|
|
if match == nil {
|
|
t.Fatal("Did not find pattern in body: " +
|
|
signInSkipProvider + "\nBody:\n" + body)
|
|
}
|
|
}
|
|
|
|
func TestSignInPageSkipProviderDirect(t *testing.T) {
|
|
sipTest, err := NewSignInPageTest(true)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
endpoint := "/sign_in"
|
|
|
|
code, body := sipTest.GetEndpoint(endpoint)
|
|
assert.Equal(t, 302, code)
|
|
|
|
match := sipTest.signInProviderRegexp.FindStringSubmatch(body)
|
|
if match == nil {
|
|
t.Fatal("Did not find pattern in body: " +
|
|
signInSkipProvider + "\nBody:\n" + body)
|
|
}
|
|
}
|
|
|
|
type ProcessCookieTest struct {
|
|
opts *options.Options
|
|
proxy *OAuthProxy
|
|
rw *httptest.ResponseRecorder
|
|
req *http.Request
|
|
validateUser bool
|
|
}
|
|
|
|
type ProcessCookieTestOpts struct {
|
|
providerValidateCookieResponse bool
|
|
}
|
|
|
|
type OptionsModifier func(*options.Options)
|
|
|
|
func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifier) (*ProcessCookieTest, error) {
|
|
var pcTest ProcessCookieTest
|
|
|
|
pcTest.opts = baseTestOptions()
|
|
for _, modifier := range modifiers {
|
|
modifier(pcTest.opts)
|
|
}
|
|
// First, set the CookieRefresh option so proxy.AesCipher is created,
|
|
// needed to encrypt the access_token.
|
|
pcTest.opts.Cookie.Refresh = time.Hour
|
|
err := validation.Validate(pcTest.opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
|
|
return pcTest.validateUser
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
testProvider := &TestProvider{
|
|
ProviderData: &providers.ProviderData{},
|
|
ValidToken: opts.providerValidateCookieResponse,
|
|
}
|
|
|
|
groups := pcTest.opts.Providers[0].AllowedGroups
|
|
testProvider.AllowedGroups = make(map[string]struct{}, len(groups))
|
|
for _, group := range groups {
|
|
testProvider.AllowedGroups[group] = struct{}{}
|
|
}
|
|
pcTest.proxy.provider = testProvider
|
|
|
|
// Now, zero-out proxy.CookieRefresh for the cases that don't involve
|
|
// access_token validation.
|
|
pcTest.proxy.CookieOptions.Refresh = time.Duration(0)
|
|
pcTest.rw = httptest.NewRecorder()
|
|
pcTest.req, _ = http.NewRequest("GET", "/", strings.NewReader(""))
|
|
pcTest.validateUser = true
|
|
return &pcTest, nil
|
|
}
|
|
|
|
func NewProcessCookieTestWithDefaults() (*ProcessCookieTest, error) {
|
|
return NewProcessCookieTest(ProcessCookieTestOpts{
|
|
providerValidateCookieResponse: true,
|
|
})
|
|
}
|
|
|
|
func NewProcessCookieTestWithOptionsModifiers(modifiers ...OptionsModifier) (*ProcessCookieTest, error) {
|
|
return NewProcessCookieTest(ProcessCookieTestOpts{
|
|
providerValidateCookieResponse: true,
|
|
}, modifiers...)
|
|
}
|
|
|
|
func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState) error {
|
|
err := p.proxy.SaveSession(p.rw, p.req, s)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, cookie := range p.rw.Result().Cookies() {
|
|
p.req.AddCookie(cookie)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, error) {
|
|
return p.proxy.LoadCookiedSession(p.req)
|
|
}
|
|
|
|
func TestLoadCookiedSession(t *testing.T) {
|
|
pcTest, err := NewProcessCookieTestWithDefaults()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
created := time.Now()
|
|
startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token", CreatedAt: &created}
|
|
err = pcTest.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
session, err := pcTest.LoadCookiedSession()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
assert.Equal(t, startSession.Email, session.Email)
|
|
assert.Equal(t, "", session.User)
|
|
assert.Equal(t, startSession.AccessToken, session.AccessToken)
|
|
}
|
|
|
|
func TestProcessCookieNoCookieError(t *testing.T) {
|
|
pcTest, err := NewProcessCookieTestWithDefaults()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
session, err := pcTest.LoadCookiedSession()
|
|
assert.Error(t, err, "cookie \"_oauth2_proxy\" not present")
|
|
if session != nil {
|
|
t.Errorf("expected nil session. got %#v", session)
|
|
}
|
|
}
|
|
|
|
func TestProcessCookieRefreshNotSet(t *testing.T) {
|
|
pcTest, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) {
|
|
opts.Cookie.Expire = time.Duration(23) * time.Hour
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
reference := time.Now().Add(time.Duration(-2) * time.Hour)
|
|
|
|
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
|
|
err = pcTest.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
session, err := pcTest.LoadCookiedSession()
|
|
assert.Equal(t, nil, err)
|
|
if session.Age() < time.Duration(-2)*time.Hour {
|
|
t.Errorf("cookie too young %v", session.Age())
|
|
}
|
|
assert.Equal(t, startSession.Email, session.Email)
|
|
}
|
|
|
|
func TestProcessCookieFailIfCookieExpired(t *testing.T) {
|
|
pcTest, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) {
|
|
opts.Cookie.Expire = time.Duration(24) * time.Hour
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
|
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
|
|
err = pcTest.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
session, err := pcTest.LoadCookiedSession()
|
|
assert.NotEqual(t, nil, err)
|
|
if session != nil {
|
|
t.Errorf("expected nil session %#v", session)
|
|
}
|
|
}
|
|
|
|
func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
|
|
pcTest, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) {
|
|
opts.Cookie.Expire = time.Duration(24) * time.Hour
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
|
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
|
|
err = pcTest.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
pcTest.proxy.CookieOptions.Refresh = time.Hour
|
|
session, err := pcTest.LoadCookiedSession()
|
|
assert.NotEqual(t, nil, err)
|
|
if session != nil {
|
|
t.Errorf("expected nil session %#v", session)
|
|
}
|
|
}
|
|
|
|
func NewUserInfoEndpointTest() (*ProcessCookieTest, error) {
|
|
pcTest, err := NewProcessCookieTestWithDefaults()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
pcTest.req, _ = http.NewRequest("GET",
|
|
pcTest.opts.ProxyPrefix+"/userinfo", nil)
|
|
return pcTest, nil
|
|
}
|
|
|
|
func TestUserInfoEndpointAccepted(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
session *sessions.SessionState
|
|
expectedResponse string
|
|
}{
|
|
{
|
|
name: "Full session",
|
|
session: &sessions.SessionState{
|
|
User: "john.doe",
|
|
Email: "john.doe@example.com",
|
|
Groups: []string{"example", "groups"},
|
|
AccessToken: "my_access_token",
|
|
},
|
|
expectedResponse: "{\"user\":\"john.doe\",\"email\":\"john.doe@example.com\",\"groups\":[\"example\",\"groups\"]}\n",
|
|
},
|
|
{
|
|
name: "Minimal session",
|
|
session: &sessions.SessionState{
|
|
User: "john.doe",
|
|
Email: "john.doe@example.com",
|
|
Groups: []string{"example", "groups"},
|
|
},
|
|
expectedResponse: "{\"user\":\"john.doe\",\"email\":\"john.doe@example.com\",\"groups\":[\"example\",\"groups\"]}\n",
|
|
},
|
|
{
|
|
name: "No groups",
|
|
session: &sessions.SessionState{
|
|
User: "john.doe",
|
|
Email: "john.doe@example.com",
|
|
AccessToken: "my_access_token",
|
|
},
|
|
expectedResponse: "{\"user\":\"john.doe\",\"email\":\"john.doe@example.com\"}\n",
|
|
},
|
|
{
|
|
name: "With Preferred Username",
|
|
session: &sessions.SessionState{
|
|
User: "john.doe",
|
|
PreferredUsername: "john",
|
|
Email: "john.doe@example.com",
|
|
Groups: []string{"example", "groups"},
|
|
AccessToken: "my_access_token",
|
|
},
|
|
expectedResponse: "{\"user\":\"john.doe\",\"email\":\"john.doe@example.com\",\"groups\":[\"example\",\"groups\"],\"preferredUsername\":\"john\"}\n",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
test, err := NewUserInfoEndpointTest()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
err = test.SaveSession(tc.session)
|
|
assert.NoError(t, err)
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
assert.Equal(t, http.StatusOK, test.rw.Code)
|
|
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
|
assert.Equal(t, tc.expectedResponse, string(bodyBytes))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestUserInfoEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
|
|
test, err := NewUserInfoEndpointTest()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
|
}
|
|
|
|
func TestEncodedUrlsStayEncoded(t *testing.T) {
|
|
encodeTest, err := NewSignInPageTest(false)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
code, _ := encodeTest.GetEndpoint("/%2F/test1/%2F/test2")
|
|
assert.Equal(t, 403, code)
|
|
}
|
|
|
|
func NewAuthOnlyEndpointTest(querystring string, modifiers ...OptionsModifier) (*ProcessCookieTest, error) {
|
|
pcTest, err := NewProcessCookieTestWithOptionsModifiers(modifiers...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
pcTest.req, _ = http.NewRequest(
|
|
"GET",
|
|
fmt.Sprintf("%s/auth%s", pcTest.opts.ProxyPrefix, querystring),
|
|
nil)
|
|
return pcTest, nil
|
|
}
|
|
|
|
func TestAuthOnlyEndpointAccepted(t *testing.T) {
|
|
test, err := NewAuthOnlyEndpointTest("")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
created := time.Now()
|
|
startSession := &sessions.SessionState{
|
|
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &created}
|
|
err = test.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
assert.Equal(t, http.StatusAccepted, test.rw.Code)
|
|
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
|
assert.Equal(t, "", string(bodyBytes))
|
|
}
|
|
|
|
func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
|
|
test, err := NewAuthOnlyEndpointTest("")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
|
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
|
assert.Equal(t, "Unauthorized\n", string(bodyBytes))
|
|
}
|
|
|
|
func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
|
|
test, err := NewAuthOnlyEndpointTest("", func(opts *options.Options) {
|
|
opts.Cookie.Expire = time.Duration(24) * time.Hour
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
|
startSession := &sessions.SessionState{
|
|
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
|
|
err = test.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
|
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
|
assert.Equal(t, "Unauthorized\n", string(bodyBytes))
|
|
}
|
|
|
|
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
|
|
test, err := NewAuthOnlyEndpointTest("")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
created := time.Now()
|
|
startSession := &sessions.SessionState{
|
|
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &created}
|
|
err = test.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
test.validateUser = false
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
|
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
|
assert.Equal(t, "Unauthorized\n", string(bodyBytes))
|
|
}
|
|
|
|
func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
|
|
var pcTest ProcessCookieTest
|
|
|
|
pcTest.opts = baseTestOptions()
|
|
pcTest.opts.InjectResponseHeaders = []options.Header{
|
|
{
|
|
Name: "X-Auth-Request-User",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "user",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Auth-Request-Email",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "email",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Auth-Request-Groups",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "groups",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Forwarded-Preferred-Username",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "preferred_username",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
pcTest.opts.Providers[0].AllowedGroups = []string{"oauth_groups"}
|
|
err := validation.Validate(pcTest.opts)
|
|
assert.NoError(t, err)
|
|
|
|
pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
|
|
return pcTest.validateUser
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
pcTest.proxy.provider = &TestProvider{
|
|
ProviderData: &providers.ProviderData{},
|
|
ValidToken: true,
|
|
}
|
|
|
|
pcTest.validateUser = true
|
|
|
|
pcTest.rw = httptest.NewRecorder()
|
|
pcTest.req, _ = http.NewRequest("GET",
|
|
pcTest.opts.ProxyPrefix+"/auth", nil)
|
|
|
|
created := time.Now()
|
|
startSession := &sessions.SessionState{
|
|
User: "oauth_user", Groups: []string{"oauth_groups"}, Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
|
|
err = pcTest.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
|
|
assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
|
|
assert.Equal(t, "oauth_user", pcTest.rw.Header().Get("X-Auth-Request-User"))
|
|
assert.Equal(t, startSession.Groups, pcTest.rw.Header().Values("X-Auth-Request-Groups"))
|
|
assert.Equal(t, "oauth_user@example.com", pcTest.rw.Header().Get("X-Auth-Request-Email"))
|
|
}
|
|
|
|
func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) {
|
|
var pcTest ProcessCookieTest
|
|
|
|
pcTest.opts = baseTestOptions()
|
|
pcTest.opts.InjectResponseHeaders = []options.Header{
|
|
{
|
|
Name: "X-Auth-Request-User",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "user",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Auth-Request-Email",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "email",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Auth-Request-Groups",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "groups",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Forwarded-Preferred-Username",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "preferred_username",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "Authorization",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "user",
|
|
BasicAuthPassword: &options.SecretSource{
|
|
Value: []byte("This is a secure password"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
err := validation.Validate(pcTest.opts)
|
|
assert.NoError(t, err)
|
|
|
|
pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
|
|
return pcTest.validateUser
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
pcTest.proxy.provider = &TestProvider{
|
|
ProviderData: &providers.ProviderData{},
|
|
ValidToken: true,
|
|
}
|
|
|
|
pcTest.validateUser = true
|
|
|
|
pcTest.rw = httptest.NewRecorder()
|
|
pcTest.req, _ = http.NewRequest("GET",
|
|
pcTest.opts.ProxyPrefix+"/auth", nil)
|
|
|
|
created := time.Now()
|
|
startSession := &sessions.SessionState{
|
|
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
|
|
err = pcTest.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
|
|
assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
|
|
assert.Equal(t, "oauth_user", pcTest.rw.Header().Values("X-Auth-Request-User")[0])
|
|
assert.Equal(t, "oauth_user@example.com", pcTest.rw.Header().Values("X-Auth-Request-Email")[0])
|
|
expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte("oauth_user:This is a secure password"))
|
|
assert.Equal(t, expectedHeader, pcTest.rw.Header().Values("Authorization")[0])
|
|
}
|
|
|
|
func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) {
|
|
var pcTest ProcessCookieTest
|
|
|
|
pcTest.opts = baseTestOptions()
|
|
pcTest.opts.InjectResponseHeaders = []options.Header{
|
|
{
|
|
Name: "X-Auth-Request-User",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "user",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Auth-Request-Email",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "email",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Auth-Request-Groups",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "groups",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Forwarded-Preferred-Username",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "preferred_username",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
err := validation.Validate(pcTest.opts)
|
|
assert.NoError(t, err)
|
|
|
|
pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
|
|
return pcTest.validateUser
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
pcTest.proxy.provider = &TestProvider{
|
|
ProviderData: &providers.ProviderData{},
|
|
ValidToken: true,
|
|
}
|
|
|
|
pcTest.validateUser = true
|
|
|
|
pcTest.rw = httptest.NewRecorder()
|
|
pcTest.req, _ = http.NewRequest("GET",
|
|
pcTest.opts.ProxyPrefix+"/auth", nil)
|
|
|
|
created := time.Now()
|
|
startSession := &sessions.SessionState{
|
|
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
|
|
err = pcTest.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
|
|
assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
|
|
assert.Equal(t, "oauth_user", pcTest.rw.Header().Values("X-Auth-Request-User")[0])
|
|
assert.Equal(t, "oauth_user@example.com", pcTest.rw.Header().Values("X-Auth-Request-Email")[0])
|
|
assert.Equal(t, 0, len(pcTest.rw.Header().Values("Authorization")), "should not have Authorization header entries")
|
|
}
|
|
|
|
func TestAuthSkippedForPreflightRequests(t *testing.T) {
|
|
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(200)
|
|
_, err := w.Write([]byte("response"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}))
|
|
t.Cleanup(upstreamServer.Close)
|
|
|
|
opts := baseTestOptions()
|
|
opts.UpstreamServers = options.UpstreamConfig{
|
|
Upstreams: []options.Upstream{
|
|
{
|
|
ID: upstreamServer.URL,
|
|
Path: "/",
|
|
URI: upstreamServer.URL,
|
|
},
|
|
},
|
|
}
|
|
opts.SkipAuthPreflight = true
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
|
|
upstreamURL, _ := url.Parse(upstreamServer.URL)
|
|
|
|
proxy, err := NewOAuthProxy(opts, func(string) bool { return false })
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
proxy.provider = NewTestProvider(upstreamURL, "")
|
|
rw := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("OPTIONS", "/preflight-request", nil)
|
|
proxy.ServeHTTP(rw, req)
|
|
|
|
assert.Equal(t, 200, rw.Code)
|
|
assert.Equal(t, "response", rw.Body.String())
|
|
}
|
|
|
|
type SignatureAuthenticator struct {
|
|
auth hmacauth.HmacAuth
|
|
}
|
|
|
|
func (v *SignatureAuthenticator) Authenticate(w http.ResponseWriter, r *http.Request) {
|
|
result, headerSig, computedSig := v.auth.AuthenticateRequest(r)
|
|
|
|
var msg string
|
|
switch result {
|
|
case hmacauth.ResultNoSignature:
|
|
msg = "no signature received"
|
|
case hmacauth.ResultMatch:
|
|
msg = "signatures match"
|
|
case hmacauth.ResultMismatch:
|
|
msg = fmt.Sprintf(
|
|
"signatures do not match:\n received: %s\n computed: %s",
|
|
headerSig,
|
|
computedSig)
|
|
default:
|
|
panic("unknown result value: " + result.String())
|
|
}
|
|
|
|
_, err := w.Write([]byte(msg))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
type SignatureTest struct {
|
|
opts *options.Options
|
|
upstream *httptest.Server
|
|
upstreamHost string
|
|
provider *httptest.Server
|
|
header http.Header
|
|
rw *httptest.ResponseRecorder
|
|
authenticator *SignatureAuthenticator
|
|
authProvider providers.Provider
|
|
}
|
|
|
|
func NewSignatureTest() (*SignatureTest, error) {
|
|
opts := baseTestOptions()
|
|
opts.EmailDomains = []string{"acm.org"}
|
|
|
|
authenticator := &SignatureAuthenticator{}
|
|
upstreamServer := httptest.NewServer(
|
|
http.HandlerFunc(authenticator.Authenticate))
|
|
upstreamURL, err := url.Parse(upstreamServer.URL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
opts.UpstreamServers = options.UpstreamConfig{
|
|
Upstreams: []options.Upstream{
|
|
{
|
|
ID: upstreamServer.URL,
|
|
Path: "/",
|
|
URI: upstreamServer.URL,
|
|
},
|
|
},
|
|
}
|
|
|
|
providerHandler := func(w http.ResponseWriter, r *http.Request) {
|
|
_, err := w.Write([]byte(`{"access_token": "my_auth_token"}`))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
provider := httptest.NewServer(http.HandlerFunc(providerHandler))
|
|
providerURL, err := url.Parse(provider.URL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
testProvider := NewTestProvider(providerURL, "mbland@acm.org")
|
|
|
|
return &SignatureTest{
|
|
opts,
|
|
upstreamServer,
|
|
upstreamURL.Host,
|
|
provider,
|
|
make(http.Header),
|
|
httptest.NewRecorder(),
|
|
authenticator,
|
|
testProvider,
|
|
}, nil
|
|
}
|
|
|
|
func (st *SignatureTest) Close() {
|
|
st.provider.Close()
|
|
st.upstream.Close()
|
|
}
|
|
|
|
// fakeNetConn simulates an http.Request.Body buffer that will be consumed
|
|
// when it is read by the hmacauth.HmacAuth if not handled properly. See:
|
|
// https://github.com/18F/hmacauth/pull/4
|
|
type fakeNetConn struct {
|
|
reqBody string
|
|
}
|
|
|
|
func (fnc *fakeNetConn) Read(p []byte) (n int, err error) {
|
|
if bodyLen := len(fnc.reqBody); bodyLen != 0 {
|
|
copy(p, fnc.reqBody)
|
|
fnc.reqBody = ""
|
|
return bodyLen, io.EOF
|
|
}
|
|
return 0, io.EOF
|
|
}
|
|
|
|
func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) error {
|
|
err := validation.Validate(st.opts)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
proxy, err := NewOAuthProxy(st.opts, func(email string) bool { return true })
|
|
if err != nil {
|
|
return err
|
|
}
|
|
proxy.provider = st.authProvider
|
|
|
|
var bodyBuf io.ReadCloser
|
|
if body != "" {
|
|
bodyBuf = ioutil.NopCloser(&fakeNetConn{reqBody: body})
|
|
}
|
|
req := httptest.NewRequest(method, "/foo/bar", bodyBuf)
|
|
req.Header = st.header
|
|
|
|
state := &sessions.SessionState{
|
|
Email: "mbland@acm.org", AccessToken: "my_access_token"}
|
|
err = proxy.SaveSession(st.rw, req, state)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, c := range st.rw.Result().Cookies() {
|
|
req.AddCookie(c)
|
|
}
|
|
// This is used by the upstream to validate the signature.
|
|
st.authenticator.auth = hmacauth.NewHmacAuth(
|
|
crypto.SHA1, []byte(key), upstream.SignatureHeader, upstream.SignatureHeaders)
|
|
proxy.ServeHTTP(st.rw, req)
|
|
|
|
return nil
|
|
}
|
|
|
|
func TestRequestSignature(t *testing.T) {
|
|
testCases := map[string]struct {
|
|
method string
|
|
body string
|
|
key string
|
|
resp string
|
|
}{
|
|
"No request signature": {
|
|
method: "GET",
|
|
body: "",
|
|
key: "",
|
|
resp: "no signature received",
|
|
},
|
|
"Get request": {
|
|
method: "GET",
|
|
body: "",
|
|
key: "7d9e1aa87a5954e6f9fc59266b3af9d7c35fda2d",
|
|
resp: "signatures match",
|
|
},
|
|
"Post request": {
|
|
method: "POST",
|
|
body: `{ "hello": "world!" }`,
|
|
key: "d90df39e2d19282840252612dd7c81421a372f61",
|
|
resp: "signatures match",
|
|
},
|
|
}
|
|
for name, tc := range testCases {
|
|
t.Run(name, func(t *testing.T) {
|
|
st, err := NewSignatureTest()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Cleanup(st.Close)
|
|
if tc.key != "" {
|
|
st.opts.SignatureKey = fmt.Sprintf("sha1:%s", tc.key)
|
|
}
|
|
err = st.MakeRequestWithExpectedKey(tc.method, tc.body, tc.key)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, 200, st.rw.Code)
|
|
assert.Equal(t, tc.resp, st.rw.Body.String())
|
|
})
|
|
}
|
|
}
|
|
|
|
type ajaxRequestTest struct {
|
|
opts *options.Options
|
|
proxy *OAuthProxy
|
|
}
|
|
|
|
func newAjaxRequestTest(forceJSONErrors bool) (*ajaxRequestTest, error) {
|
|
test := &ajaxRequestTest{}
|
|
test.opts = baseTestOptions()
|
|
test.opts.ForceJSONErrors = forceJSONErrors
|
|
err := validation.Validate(test.opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
test.proxy, err = NewOAuthProxy(test.opts, func(email string) bool {
|
|
return true
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return test, nil
|
|
}
|
|
|
|
func (test *ajaxRequestTest) getEndpoint(endpoint string, header http.Header) (int, http.Header, []byte, error) {
|
|
rw := httptest.NewRecorder()
|
|
req, err := http.NewRequest(http.MethodGet, endpoint, strings.NewReader(""))
|
|
if err != nil {
|
|
return 0, nil, nil, err
|
|
}
|
|
req.Header = header
|
|
test.proxy.ServeHTTP(rw, req)
|
|
return rw.Code, rw.Header(), rw.Body.Bytes(), nil
|
|
}
|
|
|
|
func testAjaxUnauthorizedRequest(t *testing.T, header http.Header, forceJSONErrors bool) {
|
|
test, err := newAjaxRequestTest(forceJSONErrors)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
endpoint := "/test"
|
|
|
|
code, rh, body, err := test.getEndpoint(endpoint, header)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, http.StatusUnauthorized, code)
|
|
mime := rh.Get("Content-Type")
|
|
assert.Equal(t, applicationJSON, mime)
|
|
assert.Equal(t, []byte("{}"), body)
|
|
}
|
|
func TestAjaxUnauthorizedRequest1(t *testing.T) {
|
|
header := make(http.Header)
|
|
header.Add("accept", applicationJSON)
|
|
|
|
testAjaxUnauthorizedRequest(t, header, false)
|
|
}
|
|
|
|
func TestAjaxUnauthorizedRequest2(t *testing.T) {
|
|
header := make(http.Header)
|
|
header.Add("Accept", applicationJSON)
|
|
|
|
testAjaxUnauthorizedRequest(t, header, false)
|
|
}
|
|
|
|
func TestAjaxUnauthorizedRequestAccept1(t *testing.T) {
|
|
header := make(http.Header)
|
|
header.Add("Accept", "application/json, text/plain, */*")
|
|
|
|
testAjaxUnauthorizedRequest(t, header, false)
|
|
}
|
|
|
|
func TestForceJSONErrorsUnauthorizedRequest(t *testing.T) {
|
|
testAjaxUnauthorizedRequest(t, nil, true)
|
|
}
|
|
|
|
func TestAjaxForbiddendRequest(t *testing.T) {
|
|
test, err := newAjaxRequestTest(false)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
endpoint := "/test"
|
|
header := make(http.Header)
|
|
code, rh, _, err := test.getEndpoint(endpoint, header)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, http.StatusForbidden, code)
|
|
mime := rh.Get("Content-Type")
|
|
assert.NotEqual(t, applicationJSON, mime)
|
|
}
|
|
|
|
func TestClearSplitCookie(t *testing.T) {
|
|
opts := baseTestOptions()
|
|
opts.Cookie.Secret = base64CookieSecret
|
|
opts.Cookie.Name = "oauth2"
|
|
opts.Cookie.Domains = []string{"abc"}
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
|
|
store, err := sessionscookie.NewCookieSessionStore(&opts.Session, &opts.Cookie)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
p := OAuthProxy{CookieOptions: &opts.Cookie, sessionStore: store}
|
|
var rw = httptest.NewRecorder()
|
|
req := httptest.NewRequest("get", "/", nil)
|
|
|
|
req.AddCookie(&http.Cookie{
|
|
Name: "test1",
|
|
Value: "test1",
|
|
})
|
|
req.AddCookie(&http.Cookie{
|
|
Name: "oauth2_0",
|
|
Value: "oauth2_0",
|
|
})
|
|
req.AddCookie(&http.Cookie{
|
|
Name: "oauth2_1",
|
|
Value: "oauth2_1",
|
|
})
|
|
|
|
err = p.ClearSessionCookie(rw, req)
|
|
assert.NoError(t, err)
|
|
header := rw.Header()
|
|
|
|
assert.Equal(t, 2, len(header["Set-Cookie"]), "should have 3 set-cookie header entries")
|
|
}
|
|
|
|
func TestClearSingleCookie(t *testing.T) {
|
|
opts := baseTestOptions()
|
|
opts.Cookie.Name = "oauth2"
|
|
opts.Cookie.Domains = []string{"abc"}
|
|
store, err := sessionscookie.NewCookieSessionStore(&opts.Session, &opts.Cookie)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
p := OAuthProxy{CookieOptions: &opts.Cookie, sessionStore: store}
|
|
var rw = httptest.NewRecorder()
|
|
req := httptest.NewRequest("get", "/", nil)
|
|
|
|
req.AddCookie(&http.Cookie{
|
|
Name: "test1",
|
|
Value: "test1",
|
|
})
|
|
req.AddCookie(&http.Cookie{
|
|
Name: "oauth2",
|
|
Value: "oauth2",
|
|
})
|
|
|
|
err = p.ClearSessionCookie(rw, req)
|
|
assert.NoError(t, err)
|
|
header := rw.Header()
|
|
|
|
assert.Equal(t, 1, len(header["Set-Cookie"]), "should have 1 set-cookie header entries")
|
|
}
|
|
|
|
type NoOpKeySet struct {
|
|
}
|
|
|
|
func (NoOpKeySet) VerifySignature(_ context.Context, jwt string) (payload []byte, err error) {
|
|
splitStrings := strings.Split(jwt, ".")
|
|
payloadString := splitStrings[1]
|
|
return base64.RawURLEncoding.DecodeString(payloadString)
|
|
}
|
|
|
|
func TestGetJwtSession(t *testing.T) {
|
|
/* token payload:
|
|
{
|
|
"sub": "1234567890",
|
|
"aud": "https://test.myapp.com",
|
|
"name": "John Doe",
|
|
"email": "john@example.com",
|
|
"iss": "https://issuer.example.com",
|
|
"iat": 1553691215,
|
|
"exp": 1912151821
|
|
}
|
|
*/
|
|
goodJwt := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." +
|
|
"eyJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjoiaHR0cHM6Ly90ZXN0Lm15YXBwLmNvbSIsIm5hbWUiOiJKb2huIERvZSIsImVtY" +
|
|
"WlsIjoiam9obkBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwiaWF0IjoxNTUzNjkxMj" +
|
|
"E1LCJleHAiOjE5MTIxNTE4MjF9." +
|
|
"rLVyzOnEldUq_pNkfa-WiV8TVJYWyZCaM2Am_uo8FGg11zD7l-qmz3x1seTvqpH6Y0Ty00fmv6dJnGnC8WMnPXQiodRTfhBSe" +
|
|
"OKZMu0HkMD2sg52zlKkbfLTO6ic5VnbVgwjjrB8am_Ta6w7kyFUaB5C1BsIrrLMldkWEhynbb8"
|
|
|
|
keyset := NoOpKeySet{}
|
|
verifier := oidc.NewVerifier("https://issuer.example.com", keyset,
|
|
&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true,
|
|
SkipClientIDCheck: true})
|
|
verificationOptions := internaloidc.IDTokenVerificationOptions{
|
|
AudienceClaims: []string{"aud"},
|
|
ClientID: "https://test.myapp.com",
|
|
ExtraAudiences: []string{},
|
|
}
|
|
internalVerifier := internaloidc.NewVerifier(verifier, verificationOptions)
|
|
|
|
test, err := NewAuthOnlyEndpointTest("", func(opts *options.Options) {
|
|
opts.InjectRequestHeaders = []options.Header{
|
|
{
|
|
Name: "Authorization",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "id_token",
|
|
Prefix: "Bearer ",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Forwarded-User",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "user",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Forwarded-Email",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "email",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
opts.InjectResponseHeaders = []options.Header{
|
|
{
|
|
Name: "Authorization",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "id_token",
|
|
Prefix: "Bearer ",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Auth-Request-User",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "user",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Auth-Request-Email",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "email",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
opts.SkipJwtBearerTokens = true
|
|
opts.SetJWTBearerVerifiers(append(opts.GetJWTBearerVerifiers(), internalVerifier))
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
tp, _ := test.proxy.provider.(*TestProvider)
|
|
tp.GroupValidator = func(s string) bool {
|
|
return true
|
|
}
|
|
|
|
authHeader := fmt.Sprintf("Bearer %s", goodJwt)
|
|
test.req.Header = map[string][]string{
|
|
"Authorization": {authHeader},
|
|
}
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
if test.rw.Code >= 400 {
|
|
t.Fatalf("expected 3xx got %d", test.rw.Code)
|
|
}
|
|
|
|
// Check PassAuthorization, should overwrite Basic header
|
|
assert.Equal(t, test.req.Header.Get("Authorization"), authHeader)
|
|
assert.Equal(t, test.req.Header.Get("X-Forwarded-User"), "1234567890")
|
|
assert.Equal(t, test.req.Header.Get("X-Forwarded-Email"), "john@example.com")
|
|
|
|
// SetAuthorization and SetXAuthRequest
|
|
assert.Equal(t, test.rw.Header().Get("Authorization"), authHeader)
|
|
assert.Equal(t, test.rw.Header().Get("X-Auth-Request-User"), "1234567890")
|
|
assert.Equal(t, test.rw.Header().Get("X-Auth-Request-Email"), "john@example.com")
|
|
}
|
|
|
|
func Test_prepareNoCache(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
prepareNoCache(w)
|
|
})
|
|
mux := http.NewServeMux()
|
|
mux.Handle("/", handler)
|
|
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
for k, v := range noCacheHeaders {
|
|
assert.Equal(t, rec.Header().Get(k), v)
|
|
}
|
|
}
|
|
|
|
func Test_noCacheHeaders(t *testing.T) {
|
|
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, err := w.Write([]byte("upstream"))
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
}))
|
|
t.Cleanup(upstreamServer.Close)
|
|
|
|
opts := baseTestOptions()
|
|
opts.UpstreamServers = options.UpstreamConfig{
|
|
Upstreams: []options.Upstream{
|
|
{
|
|
ID: upstreamServer.URL,
|
|
Path: "/",
|
|
URI: upstreamServer.URL,
|
|
},
|
|
},
|
|
}
|
|
opts.SkipAuthRegex = []string{".*"}
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
t.Run("not exist in response from upstream", func(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/upstream", nil)
|
|
proxy.ServeHTTP(rec, req)
|
|
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
assert.Equal(t, "upstream", rec.Body.String())
|
|
|
|
// checking noCacheHeaders does not exists in response headers from upstream
|
|
for k := range noCacheHeaders {
|
|
assert.Equal(t, "", rec.Header().Get(k))
|
|
}
|
|
})
|
|
|
|
t.Run("has no-cache", func(t *testing.T) {
|
|
tests := []struct {
|
|
path string
|
|
hasNoCache bool
|
|
}{
|
|
{
|
|
path: "/oauth2/sign_in",
|
|
hasNoCache: true,
|
|
},
|
|
{
|
|
path: "/oauth2/sign_out",
|
|
hasNoCache: true,
|
|
},
|
|
{
|
|
path: "/oauth2/start",
|
|
hasNoCache: true,
|
|
},
|
|
{
|
|
path: "/oauth2/callback",
|
|
hasNoCache: true,
|
|
},
|
|
{
|
|
path: "/oauth2/auth",
|
|
hasNoCache: false,
|
|
},
|
|
{
|
|
path: "/oauth2/userinfo",
|
|
hasNoCache: true,
|
|
},
|
|
{
|
|
path: "/upstream",
|
|
hasNoCache: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.path, func(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, tt.path, nil)
|
|
proxy.ServeHTTP(rec, req)
|
|
cacheControl := rec.Result().Header.Get("Cache-Control")
|
|
if tt.hasNoCache != (strings.Contains(cacheControl, "no-cache")) {
|
|
t.Errorf(`unexpected "Cache-Control" header: %s`, cacheControl)
|
|
}
|
|
})
|
|
}
|
|
|
|
})
|
|
}
|
|
|
|
func baseTestOptions() *options.Options {
|
|
opts := options.NewOptions()
|
|
opts.Cookie.Secret = rawCookieSecret
|
|
opts.Providers[0].ID = "providerID"
|
|
opts.Providers[0].ClientID = clientID
|
|
opts.Providers[0].ClientSecret = clientSecret
|
|
opts.EmailDomains = []string{"*"}
|
|
|
|
// Default injected headers for legacy configuration
|
|
opts.InjectRequestHeaders = []options.Header{
|
|
{
|
|
Name: "Authorization",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "user",
|
|
BasicAuthPassword: &options.SecretSource{
|
|
Value: []byte(base64.StdEncoding.EncodeToString([]byte("This is a secure password"))),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Forwarded-User",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "user",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Forwarded-Email",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "email",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
return opts
|
|
}
|
|
|
|
func TestTrustedIPs(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
trustedIPs []string
|
|
reverseProxy bool
|
|
realClientIPHeader string
|
|
req *http.Request
|
|
expectTrusted bool
|
|
}{
|
|
// Check unconfigured behavior.
|
|
{
|
|
name: "Default",
|
|
trustedIPs: nil,
|
|
reverseProxy: false,
|
|
realClientIPHeader: "X-Real-IP", // Default value
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
return req
|
|
}(),
|
|
expectTrusted: false,
|
|
},
|
|
// Check using req.RemoteAddr (Options.ReverseProxy == false).
|
|
{
|
|
name: "WithRemoteAddr",
|
|
trustedIPs: []string{"127.0.0.1"},
|
|
reverseProxy: false,
|
|
realClientIPHeader: "X-Real-IP", // Default value
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = "127.0.0.1:43670"
|
|
return req
|
|
}(),
|
|
expectTrusted: true,
|
|
},
|
|
// Check ignores req.RemoteAddr match when behind a reverse proxy / missing header.
|
|
{
|
|
name: "IgnoresRemoteAddrInReverseProxyMode",
|
|
trustedIPs: []string{"127.0.0.1"},
|
|
reverseProxy: true,
|
|
realClientIPHeader: "X-Real-IP", // Default value
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = "127.0.0.1:44324"
|
|
return req
|
|
}(),
|
|
expectTrusted: false,
|
|
},
|
|
// Check successful trusting of localhost in IPv4.
|
|
{
|
|
name: "TrustsLocalhostInReverseProxyMode",
|
|
trustedIPs: []string{"127.0.0.0/8", "::1"},
|
|
reverseProxy: true,
|
|
realClientIPHeader: "X-Forwarded-For",
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.Header.Add("X-Forwarded-For", "127.0.0.1")
|
|
return req
|
|
}(),
|
|
expectTrusted: true,
|
|
},
|
|
// Check successful trusting of localhost in IPv6.
|
|
{
|
|
name: "TrustsIP6LocalostInReverseProxyMode",
|
|
trustedIPs: []string{"127.0.0.0/8", "::1"},
|
|
reverseProxy: true,
|
|
realClientIPHeader: "X-Forwarded-For",
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.Header.Add("X-Forwarded-For", "::1")
|
|
return req
|
|
}(),
|
|
expectTrusted: true,
|
|
},
|
|
// Check does not trust random IPv4 address.
|
|
{
|
|
name: "DoesNotTrustRandomIP4Address",
|
|
trustedIPs: []string{"127.0.0.0/8", "::1"},
|
|
reverseProxy: true,
|
|
realClientIPHeader: "X-Forwarded-For",
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.Header.Add("X-Forwarded-For", "12.34.56.78")
|
|
return req
|
|
}(),
|
|
expectTrusted: false,
|
|
},
|
|
// Check does not trust random IPv6 address.
|
|
{
|
|
name: "DoesNotTrustRandomIP6Address",
|
|
trustedIPs: []string{"127.0.0.0/8", "::1"},
|
|
reverseProxy: true,
|
|
realClientIPHeader: "X-Forwarded-For",
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.Header.Add("X-Forwarded-For", "::2")
|
|
return req
|
|
}(),
|
|
expectTrusted: false,
|
|
},
|
|
// Check respects correct header.
|
|
{
|
|
name: "RespectsCorrectHeaderInReverseProxyMode",
|
|
trustedIPs: []string{"127.0.0.0/8", "::1"},
|
|
reverseProxy: true,
|
|
realClientIPHeader: "X-Forwarded-For",
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.Header.Add("X-Real-IP", "::1")
|
|
return req
|
|
}(),
|
|
expectTrusted: false,
|
|
},
|
|
// Check doesn't trust if garbage is provided.
|
|
{
|
|
name: "DoesNotTrustGarbageInReverseProxyMode",
|
|
trustedIPs: []string{"127.0.0.0/8", "::1"},
|
|
reverseProxy: true,
|
|
realClientIPHeader: "X-Forwarded-For",
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.Header.Add("X-Forwarded-For", "adsfljk29242as!!")
|
|
return req
|
|
}(),
|
|
expectTrusted: false,
|
|
},
|
|
// Check doesn't trust if garbage is provided (no reverse-proxy).
|
|
{
|
|
name: "DoesNotTrustGarbage",
|
|
trustedIPs: []string{"127.0.0.0/8", "::1"},
|
|
reverseProxy: false,
|
|
realClientIPHeader: "X-Real-IP",
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = "adsfljk29242as!!"
|
|
return req
|
|
}(),
|
|
expectTrusted: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
opts := baseTestOptions()
|
|
opts.UpstreamServers = options.UpstreamConfig{
|
|
Upstreams: []options.Upstream{
|
|
{
|
|
ID: "static",
|
|
Path: "/",
|
|
Static: true,
|
|
},
|
|
},
|
|
}
|
|
opts.TrustedIPs = tt.trustedIPs
|
|
opts.ReverseProxy = tt.reverseProxy
|
|
opts.RealClientIPHeader = tt.realClientIPHeader
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
|
|
proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
|
|
assert.NoError(t, err)
|
|
rw := httptest.NewRecorder()
|
|
|
|
proxy.ServeHTTP(rw, tt.req)
|
|
if tt.expectTrusted {
|
|
assert.Equal(t, 200, rw.Code)
|
|
} else {
|
|
assert.Equal(t, 403, rw.Code)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_buildRoutesAllowlist(t *testing.T) {
|
|
type expectedAllowedRoute struct {
|
|
method string
|
|
regexString string
|
|
}
|
|
|
|
testCases := []struct {
|
|
name string
|
|
skipAuthRegex []string
|
|
skipAuthRoutes []string
|
|
expectedRoutes []expectedAllowedRoute
|
|
shouldError bool
|
|
}{
|
|
{
|
|
name: "No skip auth configured",
|
|
skipAuthRegex: []string{},
|
|
skipAuthRoutes: []string{},
|
|
expectedRoutes: []expectedAllowedRoute{},
|
|
shouldError: false,
|
|
},
|
|
{
|
|
name: "Only skipAuthRegex configured",
|
|
skipAuthRegex: []string{
|
|
"^/foo/bar",
|
|
"^/baz/[0-9]+/thing",
|
|
},
|
|
skipAuthRoutes: []string{},
|
|
expectedRoutes: []expectedAllowedRoute{
|
|
{
|
|
method: "",
|
|
regexString: "^/foo/bar",
|
|
},
|
|
{
|
|
method: "",
|
|
regexString: "^/baz/[0-9]+/thing",
|
|
},
|
|
},
|
|
shouldError: false,
|
|
},
|
|
{
|
|
name: "Only skipAuthRoutes configured",
|
|
skipAuthRegex: []string{},
|
|
skipAuthRoutes: []string{
|
|
"GET=^/foo/bar",
|
|
"POST=^/baz/[0-9]+/thing",
|
|
"^/all/methods$",
|
|
"WEIRD=^/methods/are/allowed",
|
|
"PATCH=/second/equals?are=handled&just=fine",
|
|
},
|
|
expectedRoutes: []expectedAllowedRoute{
|
|
{
|
|
method: "GET",
|
|
regexString: "^/foo/bar",
|
|
},
|
|
{
|
|
method: "POST",
|
|
regexString: "^/baz/[0-9]+/thing",
|
|
},
|
|
{
|
|
method: "",
|
|
regexString: "^/all/methods$",
|
|
},
|
|
{
|
|
method: "WEIRD",
|
|
regexString: "^/methods/are/allowed",
|
|
},
|
|
{
|
|
method: "PATCH",
|
|
regexString: "/second/equals?are=handled&just=fine",
|
|
},
|
|
},
|
|
shouldError: false,
|
|
},
|
|
{
|
|
name: "Both skipAuthRegexes and skipAuthRoutes configured",
|
|
skipAuthRegex: []string{
|
|
"^/foo/bar/regex",
|
|
"^/baz/[0-9]+/thing/regex",
|
|
},
|
|
skipAuthRoutes: []string{
|
|
"GET=^/foo/bar",
|
|
"POST=^/baz/[0-9]+/thing",
|
|
"^/all/methods$",
|
|
},
|
|
expectedRoutes: []expectedAllowedRoute{
|
|
{
|
|
method: "",
|
|
regexString: "^/foo/bar/regex",
|
|
},
|
|
{
|
|
method: "",
|
|
regexString: "^/baz/[0-9]+/thing/regex",
|
|
},
|
|
{
|
|
method: "GET",
|
|
regexString: "^/foo/bar",
|
|
},
|
|
{
|
|
method: "POST",
|
|
regexString: "^/baz/[0-9]+/thing",
|
|
},
|
|
{
|
|
method: "",
|
|
regexString: "^/all/methods$",
|
|
},
|
|
},
|
|
shouldError: false,
|
|
},
|
|
{
|
|
name: "Invalid skipAuthRegex entry",
|
|
skipAuthRegex: []string{
|
|
"^/foo/bar",
|
|
"^/baz/[0-9]+/thing",
|
|
"(bad[regex",
|
|
},
|
|
skipAuthRoutes: []string{},
|
|
expectedRoutes: []expectedAllowedRoute{},
|
|
shouldError: true,
|
|
},
|
|
{
|
|
name: "Invalid skipAuthRoutes entry",
|
|
skipAuthRegex: []string{},
|
|
skipAuthRoutes: []string{
|
|
"GET=^/foo/bar",
|
|
"POST=^/baz/[0-9]+/thing",
|
|
"^/all/methods$",
|
|
"PUT=(bad[regex",
|
|
},
|
|
expectedRoutes: []expectedAllowedRoute{},
|
|
shouldError: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
opts := &options.Options{
|
|
SkipAuthRegex: tc.skipAuthRegex,
|
|
SkipAuthRoutes: tc.skipAuthRoutes,
|
|
}
|
|
routes, err := buildRoutesAllowlist(opts)
|
|
if tc.shouldError {
|
|
assert.Error(t, err)
|
|
return
|
|
}
|
|
assert.NoError(t, err)
|
|
|
|
for i, route := range routes {
|
|
assert.Greater(t, len(tc.expectedRoutes), i)
|
|
assert.Equal(t, route.method, tc.expectedRoutes[i].method)
|
|
assert.Equal(t, route.pathRegex.String(), tc.expectedRoutes[i].regexString)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAllowedRequest(t *testing.T) {
|
|
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(200)
|
|
_, err := w.Write([]byte("Allowed Request"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}))
|
|
t.Cleanup(upstreamServer.Close)
|
|
|
|
opts := baseTestOptions()
|
|
opts.UpstreamServers = options.UpstreamConfig{
|
|
Upstreams: []options.Upstream{
|
|
{
|
|
ID: upstreamServer.URL,
|
|
Path: "/",
|
|
URI: upstreamServer.URL,
|
|
},
|
|
},
|
|
}
|
|
opts.SkipAuthRegex = []string{
|
|
"^/skip/auth/regex$",
|
|
}
|
|
opts.SkipAuthRoutes = []string{
|
|
"GET=^/skip/auth/routes/get",
|
|
}
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
testCases := []struct {
|
|
name string
|
|
method string
|
|
url string
|
|
allowed bool
|
|
}{
|
|
{
|
|
name: "Regex GET allowed",
|
|
method: "GET",
|
|
url: "/skip/auth/regex",
|
|
allowed: true,
|
|
},
|
|
{
|
|
name: "Regex POST allowed ",
|
|
method: "POST",
|
|
url: "/skip/auth/regex",
|
|
allowed: true,
|
|
},
|
|
{
|
|
name: "Regex denied",
|
|
method: "GET",
|
|
url: "/wrong/denied",
|
|
allowed: false,
|
|
},
|
|
{
|
|
name: "Route allowed",
|
|
method: "GET",
|
|
url: "/skip/auth/routes/get",
|
|
allowed: true,
|
|
},
|
|
{
|
|
name: "Route denied with wrong method",
|
|
method: "PATCH",
|
|
url: "/skip/auth/routes/get",
|
|
allowed: false,
|
|
},
|
|
{
|
|
name: "Route denied with wrong path",
|
|
method: "GET",
|
|
url: "/skip/auth/routes/wrong/path",
|
|
allowed: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
req, err := http.NewRequest(tc.method, tc.url, nil)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tc.allowed, proxy.isAllowedRoute(req))
|
|
|
|
rw := httptest.NewRecorder()
|
|
proxy.ServeHTTP(rw, req)
|
|
|
|
if tc.allowed {
|
|
assert.Equal(t, 200, rw.Code)
|
|
assert.Equal(t, "Allowed Request", rw.Body.String())
|
|
} else {
|
|
assert.Equal(t, 403, rw.Code)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestProxyAllowedGroups(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
allowedGroups []string
|
|
groups []string
|
|
expectUnauthorized bool
|
|
}{
|
|
{"NoAllowedGroups", []string{}, []string{}, false},
|
|
{"NoAllowedGroupsUserHasGroups", []string{}, []string{"a", "b"}, false},
|
|
{"UserInAllowedGroup", []string{"a"}, []string{"a", "b"}, false},
|
|
{"UserNotInAllowedGroup", []string{"a"}, []string{"c"}, true},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
emailAddress := "test"
|
|
created := time.Now()
|
|
|
|
session := &sessions.SessionState{
|
|
Groups: tt.groups,
|
|
Email: emailAddress,
|
|
AccessToken: "oauth_token",
|
|
CreatedAt: &created,
|
|
}
|
|
|
|
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(200)
|
|
}))
|
|
t.Cleanup(upstreamServer.Close)
|
|
|
|
test, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) {
|
|
opts.Providers[0].AllowedGroups = tt.allowedGroups
|
|
opts.UpstreamServers = options.UpstreamConfig{
|
|
Upstreams: []options.Upstream{
|
|
{
|
|
ID: upstreamServer.URL,
|
|
Path: "/",
|
|
URI: upstreamServer.URL,
|
|
},
|
|
},
|
|
}
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
test.req, _ = http.NewRequest("GET", "/", nil)
|
|
|
|
test.req.Header.Add("accept", applicationJSON)
|
|
err = test.SaveSession(session)
|
|
assert.NoError(t, err)
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
|
|
if tt.expectUnauthorized {
|
|
assert.Equal(t, http.StatusForbidden, test.rw.Code)
|
|
} else {
|
|
assert.Equal(t, http.StatusOK, test.rw.Code)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAuthOnlyAllowedGroups(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
allowedGroups []string
|
|
groups []string
|
|
querystring string
|
|
expectedStatusCode int
|
|
}{
|
|
{
|
|
name: "NoAllowedGroups",
|
|
allowedGroups: []string{},
|
|
groups: []string{},
|
|
querystring: "",
|
|
expectedStatusCode: http.StatusAccepted,
|
|
},
|
|
{
|
|
name: "NoAllowedGroupsUserHasGroups",
|
|
allowedGroups: []string{},
|
|
groups: []string{"a", "b"},
|
|
querystring: "",
|
|
expectedStatusCode: http.StatusAccepted,
|
|
},
|
|
{
|
|
name: "UserInAllowedGroup",
|
|
allowedGroups: []string{"a"},
|
|
groups: []string{"a", "b"},
|
|
querystring: "",
|
|
expectedStatusCode: http.StatusAccepted,
|
|
},
|
|
{
|
|
name: "UserNotInAllowedGroup",
|
|
allowedGroups: []string{"a"},
|
|
groups: []string{"c"},
|
|
querystring: "",
|
|
expectedStatusCode: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
name: "UserInQuerystringGroup",
|
|
allowedGroups: []string{"a", "b"},
|
|
groups: []string{"a", "c"},
|
|
querystring: "?allowed_groups=a",
|
|
expectedStatusCode: http.StatusAccepted,
|
|
},
|
|
{
|
|
name: "UserInMultiParamQuerystringGroup",
|
|
allowedGroups: []string{"a", "b"},
|
|
groups: []string{"b"},
|
|
querystring: "?allowed_groups=a&allowed_groups=b,d",
|
|
expectedStatusCode: http.StatusAccepted,
|
|
},
|
|
{
|
|
name: "UserInOnlyQuerystringGroup",
|
|
allowedGroups: []string{},
|
|
groups: []string{"a", "c"},
|
|
querystring: "?allowed_groups=a,b",
|
|
expectedStatusCode: http.StatusAccepted,
|
|
},
|
|
{
|
|
name: "UserInDelimitedQuerystringGroup",
|
|
allowedGroups: []string{"a", "b", "c"},
|
|
groups: []string{"c"},
|
|
querystring: "?allowed_groups=a,c",
|
|
expectedStatusCode: http.StatusAccepted,
|
|
},
|
|
{
|
|
name: "UserNotInQuerystringGroup",
|
|
allowedGroups: []string{},
|
|
groups: []string{"c"},
|
|
querystring: "?allowed_groups=a,b",
|
|
expectedStatusCode: http.StatusForbidden,
|
|
},
|
|
{
|
|
name: "UserInConfigGroupNotInQuerystringGroup",
|
|
allowedGroups: []string{"a", "b", "c"},
|
|
groups: []string{"c"},
|
|
querystring: "?allowed_groups=a,b",
|
|
expectedStatusCode: http.StatusForbidden,
|
|
},
|
|
{
|
|
name: "UserInQuerystringGroupNotInConfigGroup",
|
|
allowedGroups: []string{"a", "b"},
|
|
groups: []string{"c"},
|
|
querystring: "?allowed_groups=b,c",
|
|
expectedStatusCode: http.StatusUnauthorized,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
emailAddress := "test"
|
|
created := time.Now()
|
|
|
|
session := &sessions.SessionState{
|
|
Groups: tc.groups,
|
|
Email: emailAddress,
|
|
AccessToken: "oauth_token",
|
|
CreatedAt: &created,
|
|
}
|
|
|
|
test, err := NewAuthOnlyEndpointTest(tc.querystring, func(opts *options.Options) {
|
|
opts.Providers[0].AllowedGroups = tc.allowedGroups
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
err = test.SaveSession(session)
|
|
assert.NoError(t, err)
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
|
|
assert.Equal(t, tc.expectedStatusCode, test.rw.Code)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAuthOnlyAllowedGroupsWithSkipMethods(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
groups []string
|
|
method string
|
|
ip string
|
|
withSession bool
|
|
expectedStatusCode int
|
|
}{
|
|
{
|
|
name: "UserWithGroupSkipAuthPreflight",
|
|
groups: []string{"a", "c"},
|
|
method: "OPTIONS",
|
|
ip: "1.2.3.5:43670",
|
|
withSession: true,
|
|
expectedStatusCode: http.StatusAccepted,
|
|
},
|
|
{
|
|
name: "UserWithGroupTrustedIp",
|
|
groups: []string{"a", "c"},
|
|
method: "GET",
|
|
ip: "1.2.3.4:43670",
|
|
withSession: true,
|
|
expectedStatusCode: http.StatusAccepted,
|
|
},
|
|
{
|
|
name: "UserWithoutGroupSkipAuthPreflight",
|
|
groups: []string{"c"},
|
|
method: "OPTIONS",
|
|
ip: "1.2.3.5:43670",
|
|
withSession: true,
|
|
expectedStatusCode: http.StatusForbidden,
|
|
},
|
|
{
|
|
name: "UserWithoutGroupTrustedIp",
|
|
groups: []string{"c"},
|
|
method: "GET",
|
|
ip: "1.2.3.4:43670",
|
|
withSession: true,
|
|
expectedStatusCode: http.StatusForbidden,
|
|
},
|
|
{
|
|
name: "UserWithoutSessionSkipAuthPreflight",
|
|
method: "OPTIONS",
|
|
ip: "1.2.3.5:43670",
|
|
withSession: false,
|
|
expectedStatusCode: http.StatusAccepted,
|
|
},
|
|
{
|
|
name: "UserWithoutSessionTrustedIp",
|
|
method: "GET",
|
|
ip: "1.2.3.4:43670",
|
|
withSession: false,
|
|
expectedStatusCode: http.StatusAccepted,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
test, err := NewAuthOnlyEndpointTest("?allowed_groups=a,b", func(opts *options.Options) {
|
|
opts.SkipAuthPreflight = true
|
|
opts.TrustedIPs = []string{"1.2.3.4"}
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
test.req.Method = tc.method
|
|
test.req.RemoteAddr = tc.ip
|
|
|
|
if tc.withSession {
|
|
created := time.Now()
|
|
session := &sessions.SessionState{
|
|
Groups: tc.groups,
|
|
Email: "test",
|
|
AccessToken: "oauth_token",
|
|
CreatedAt: &created,
|
|
}
|
|
err = test.SaveSession(session)
|
|
}
|
|
assert.NoError(t, err)
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
|
|
assert.Equal(t, tc.expectedStatusCode, test.rw.Code)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAuthOnlyAllowedEmailDomains(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
email string
|
|
querystring string
|
|
expectedStatusCode int
|
|
}{
|
|
{
|
|
name: "NotEmailRestriction",
|
|
email: "toto@example.com",
|
|
querystring: "",
|
|
expectedStatusCode: http.StatusAccepted,
|
|
},
|
|
{
|
|
name: "UserInAllowedEmailDomain",
|
|
email: "toto@example.com",
|
|
querystring: "?allowed_email_domains=example.com",
|
|
expectedStatusCode: http.StatusAccepted,
|
|
},
|
|
{
|
|
name: "UserNotInAllowedEmailDomain",
|
|
email: "toto@example.com",
|
|
querystring: "?allowed_email_domains=a.example.com",
|
|
expectedStatusCode: http.StatusForbidden,
|
|
},
|
|
{
|
|
name: "UserInAllowedEmailDomains",
|
|
email: "toto@example.com",
|
|
querystring: "?allowed_email_domains=a.example.com,b.example.com",
|
|
expectedStatusCode: http.StatusForbidden,
|
|
},
|
|
{
|
|
name: "UserInAllowedEmailDomains",
|
|
email: "toto@example.com",
|
|
querystring: "?allowed_email_domains=a.example.com,example.com",
|
|
expectedStatusCode: http.StatusAccepted,
|
|
},
|
|
{
|
|
name: "UserInAllowedEmailDomainWildcard",
|
|
email: "toto@foo.example.com",
|
|
querystring: "?allowed_email_domains=*.example.com",
|
|
expectedStatusCode: http.StatusAccepted,
|
|
},
|
|
{
|
|
name: "UserNotInAllowedEmailDomainWildcard",
|
|
email: "toto@example.com",
|
|
querystring: "?allowed_email_domains=*.a.example.com",
|
|
expectedStatusCode: http.StatusForbidden,
|
|
},
|
|
{
|
|
name: "UserInAllowedEmailDomainsWildcard",
|
|
email: "toto@example.com",
|
|
querystring: "?allowed_email_domains=*.a.example.com,*.b.example.com",
|
|
expectedStatusCode: http.StatusForbidden,
|
|
},
|
|
{
|
|
name: "UserInAllowedEmailDomainsWildcard",
|
|
email: "toto@c.example.com",
|
|
querystring: "?allowed_email_domains=a.b.c.example.com,*.c.example.com",
|
|
expectedStatusCode: http.StatusAccepted,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
groups := []string{}
|
|
|
|
created := time.Now()
|
|
|
|
session := &sessions.SessionState{
|
|
Groups: groups,
|
|
Email: tc.email,
|
|
AccessToken: "oauth_token",
|
|
CreatedAt: &created,
|
|
}
|
|
|
|
test, err := NewAuthOnlyEndpointTest(tc.querystring, func(opts *options.Options) {})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
err = test.SaveSession(session)
|
|
assert.NoError(t, err)
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
|
|
assert.Equal(t, tc.expectedStatusCode, test.rw.Code)
|
|
})
|
|
}
|
|
}
|