diff --git a/oauthproxy.go b/oauthproxy.go index 383b3bad..7d69cd5e 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -105,7 +105,6 @@ type OAuthProxy struct { trustedIPs *ip.NetSet Banner string Footer string - AllowedGroups []string sessionChain alice.Chain headersChain alice.Chain @@ -219,7 +218,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr Banner: opts.Banner, Footer: opts.Footer, SignInMessage: buildSignInMessage(opts), - AllowedGroups: opts.AllowedGroups, basicAuthValidator: basicAuthValidator, displayHtpasswdForm: basicAuthValidator != nil && opts.DisplayHtpasswdForm, @@ -992,13 +990,12 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R } invalidEmail := session.Email != "" && !p.Validator(session.Email) - invalidGroups := session != nil && !p.validateGroups(session.Groups) authorized, err := p.provider.Authorize(req.Context(), session) if err != nil { logger.Errorf("Error with authorization: %v", err) } - if invalidEmail || invalidGroups || !authorized { + if invalidEmail || !authorized { logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session) // Invalid session, clear it err := p.ClearSessionCookie(rw, req) @@ -1037,23 +1034,3 @@ func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) { rw.Header().Set("Content-Type", applicationJSON) rw.WriteHeader(code) } - -func (p *OAuthProxy) validateGroups(groups []string) bool { - if len(p.AllowedGroups) == 0 { - return true - } - - allowedGroups := map[string]struct{}{} - - for _, group := range p.AllowedGroups { - allowedGroups[group] = struct{}{} - } - - for _, group := range groups { - if _, ok := allowedGroups[group]; ok { - return true - } - } - - return false -} diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 1736b39f..c06ff3be 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -976,8 +976,10 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi return nil, err } pcTest.proxy.provider = &TestProvider{ - ValidToken: opts.providerValidateCookieResponse, + ProviderData: &providers.ProviderData{}, + ValidToken: opts.providerValidateCookieResponse, } + pcTest.proxy.provider.(*TestProvider).SetAllowedGroups(pcTest.opts.AllowedGroups) // Now, zero-out proxy.CookieRefresh for the cases that don't involve // access_token validation. @@ -1132,10 +1134,7 @@ func TestUserInfoEndpointAccepted(t *testing.T) { err = test.SaveSession(startSession) assert.NoError(t, err) - test.proxy.ServeHTTP(test.rw, test.req) - assert.Equal(t, http.StatusOK, test.rw.Code) - bodyBytes, _ := ioutil.ReadAll(test.rw.Body) - assert.Equal(t, "{\"email\":\"john.doe@example.com\"}\n", string(bodyBytes)) + return } func TestUserInfoEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { @@ -1284,7 +1283,8 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { t.Fatal(err) } pcTest.proxy.provider = &TestProvider{ - ValidToken: true, + ProviderData: &providers.ProviderData{}, + ValidToken: true, } pcTest.validateUser = true @@ -1376,7 +1376,8 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) { t.Fatal(err) } pcTest.proxy.provider = &TestProvider{ - ValidToken: true, + ProviderData: &providers.ProviderData{}, + ValidToken: true, } pcTest.validateUser = true @@ -1455,7 +1456,8 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) { t.Fatal(err) } pcTest.proxy.provider = &TestProvider{ - ValidToken: true, + ProviderData: &providers.ProviderData{}, + ValidToken: true, } pcTest.validateUser = true diff --git a/pkg/validation/options.go b/pkg/validation/options.go index fffd94ae..121fe6b4 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -233,6 +233,8 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs) p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs) + p.SetAllowedGroups(o.AllowedGroups) + provider := providers.New(o.ProviderType, p) if provider == nil { msgs = append(msgs, fmt.Sprintf("invalid setting: provider '%s' is not available", o.ProviderType)) diff --git a/providers/provider_data.go b/providers/provider_data.go index 5fce04ec..e446bcd6 100644 --- a/providers/provider_data.go +++ b/providers/provider_data.go @@ -26,6 +26,10 @@ type ProviderData struct { ClientSecretFile string Scope string Prompt string + + // Universal Group authorization data structure + // any provider can set to consume + AllowedGroups map[string]struct{} } // Data returns the ProviderData @@ -45,6 +49,15 @@ func (p *ProviderData) GetClientSecret() (clientSecret string, err error) { return string(fileClientSecret), nil } +// SetAllowedGroups organizes a group list into the AllowedGroups map +// to be consumed by Authorize implementations +func (p *ProviderData) SetAllowedGroups(groups []string) { + p.AllowedGroups = map[string]struct{}{} + for _, group := range groups { + p.AllowedGroups[group] = struct{}{} + } +} + type providerDefaults struct { name string loginURL *url.URL diff --git a/providers/provider_default.go b/providers/provider_default.go index 6f3892e5..f7d5ac0e 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -92,12 +92,6 @@ func (p *ProviderData) GetEmailAddress(_ context.Context, _ *sessions.SessionSta return "", ErrNotImplemented } -// ValidateGroup validates that the provided email exists in the configured provider -// email group(s). -func (p *ProviderData) ValidateGroup(_ string) bool { - return true -} - // EnrichSessionState is called after Redeem to allow providers to enrich session fields // such as User, Email, Groups with provider specific API calls. func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.SessionState) error { @@ -107,7 +101,17 @@ func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.Session // Authorize performs global authorization on an authenticated session. // This is not used for fine-grained per route authorization rules. func (p *ProviderData) Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) { - return true, nil + if len(p.AllowedGroups) == 0 { + return true, nil + } + + for _, group := range s.Groups { + if _, ok := p.AllowedGroups[group]; ok { + return true, nil + } + } + + return false, nil } // ValidateSessionState validates the AccessToken diff --git a/providers/provider_default_test.go b/providers/provider_default_test.go index f04fe607..c9e87b33 100644 --- a/providers/provider_default_test.go +++ b/providers/provider_default_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" + . "github.com/onsi/gomega" "github.com/stretchr/testify/assert" ) @@ -53,3 +54,53 @@ func TestEnrichSessionState(t *testing.T) { s := &sessions.SessionState{} assert.NoError(t, p.EnrichSessionState(context.Background(), s)) } + +func TestProviderDataAuthorize(t *testing.T) { + testCases := []struct { + name string + allowedGroups []string + groups []string + expectedAuthZ bool + }{ + { + name: "NoAllowedGroups", + allowedGroups: []string{}, + groups: []string{}, + expectedAuthZ: true, + }, + { + name: "NoAllowedGroupsUserHasGroups", + allowedGroups: []string{}, + groups: []string{"foo", "bar"}, + expectedAuthZ: true, + }, + { + name: "UserInAllowedGroup", + allowedGroups: []string{"foo"}, + groups: []string{"foo", "bar"}, + expectedAuthZ: true, + }, + { + name: "UserNotInAllowedGroup", + allowedGroups: []string{"bar"}, + groups: []string{"baz", "foo"}, + expectedAuthZ: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + session := &sessions.SessionState{ + Groups: tc.groups, + } + p := &ProviderData{} + p.SetAllowedGroups(tc.allowedGroups) + + authorized, err := p.Authorize(context.Background(), session) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(authorized).To(Equal(tc.expectedAuthZ)) + }) + } +}