mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-03-17 21:17:53 +02:00
Move AllowedGroups to DefaultProvider for default Authorize usage
This commit is contained in:
parent
e7ac793044
commit
eb58ea2ed9
@ -105,7 +105,6 @@ type OAuthProxy struct {
|
||||
trustedIPs *ip.NetSet
|
||||
Banner string
|
||||
Footer string
|
||||
AllowedGroups []string
|
||||
|
||||
sessionChain alice.Chain
|
||||
headersChain alice.Chain
|
||||
@ -219,7 +218,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
||||
Banner: opts.Banner,
|
||||
Footer: opts.Footer,
|
||||
SignInMessage: buildSignInMessage(opts),
|
||||
AllowedGroups: opts.AllowedGroups,
|
||||
|
||||
basicAuthValidator: basicAuthValidator,
|
||||
displayHtpasswdForm: basicAuthValidator != nil && opts.DisplayHtpasswdForm,
|
||||
@ -992,13 +990,12 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R
|
||||
}
|
||||
|
||||
invalidEmail := session.Email != "" && !p.Validator(session.Email)
|
||||
invalidGroups := session != nil && !p.validateGroups(session.Groups)
|
||||
authorized, err := p.provider.Authorize(req.Context(), session)
|
||||
if err != nil {
|
||||
logger.Errorf("Error with authorization: %v", err)
|
||||
}
|
||||
|
||||
if invalidEmail || invalidGroups || !authorized {
|
||||
if invalidEmail || !authorized {
|
||||
logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session)
|
||||
// Invalid session, clear it
|
||||
err := p.ClearSessionCookie(rw, req)
|
||||
@ -1037,23 +1034,3 @@ func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) {
|
||||
rw.Header().Set("Content-Type", applicationJSON)
|
||||
rw.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) validateGroups(groups []string) bool {
|
||||
if len(p.AllowedGroups) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
allowedGroups := map[string]struct{}{}
|
||||
|
||||
for _, group := range p.AllowedGroups {
|
||||
allowedGroups[group] = struct{}{}
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
if _, ok := allowedGroups[group]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
@ -976,8 +976,10 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi
|
||||
return nil, err
|
||||
}
|
||||
pcTest.proxy.provider = &TestProvider{
|
||||
ValidToken: opts.providerValidateCookieResponse,
|
||||
ProviderData: &providers.ProviderData{},
|
||||
ValidToken: opts.providerValidateCookieResponse,
|
||||
}
|
||||
pcTest.proxy.provider.(*TestProvider).SetAllowedGroups(pcTest.opts.AllowedGroups)
|
||||
|
||||
// Now, zero-out proxy.CookieRefresh for the cases that don't involve
|
||||
// access_token validation.
|
||||
@ -1132,10 +1134,7 @@ func TestUserInfoEndpointAccepted(t *testing.T) {
|
||||
err = test.SaveSession(startSession)
|
||||
assert.NoError(t, err)
|
||||
|
||||
test.proxy.ServeHTTP(test.rw, test.req)
|
||||
assert.Equal(t, http.StatusOK, test.rw.Code)
|
||||
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
||||
assert.Equal(t, "{\"email\":\"john.doe@example.com\"}\n", string(bodyBytes))
|
||||
return
|
||||
}
|
||||
|
||||
func TestUserInfoEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
|
||||
@ -1284,7 +1283,8 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pcTest.proxy.provider = &TestProvider{
|
||||
ValidToken: true,
|
||||
ProviderData: &providers.ProviderData{},
|
||||
ValidToken: true,
|
||||
}
|
||||
|
||||
pcTest.validateUser = true
|
||||
@ -1376,7 +1376,8 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pcTest.proxy.provider = &TestProvider{
|
||||
ValidToken: true,
|
||||
ProviderData: &providers.ProviderData{},
|
||||
ValidToken: true,
|
||||
}
|
||||
|
||||
pcTest.validateUser = true
|
||||
@ -1455,7 +1456,8 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pcTest.proxy.provider = &TestProvider{
|
||||
ValidToken: true,
|
||||
ProviderData: &providers.ProviderData{},
|
||||
ValidToken: true,
|
||||
}
|
||||
|
||||
pcTest.validateUser = true
|
||||
|
@ -233,6 +233,8 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
|
||||
p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs)
|
||||
p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs)
|
||||
|
||||
p.SetAllowedGroups(o.AllowedGroups)
|
||||
|
||||
provider := providers.New(o.ProviderType, p)
|
||||
if provider == nil {
|
||||
msgs = append(msgs, fmt.Sprintf("invalid setting: provider '%s' is not available", o.ProviderType))
|
||||
|
@ -26,6 +26,10 @@ type ProviderData struct {
|
||||
ClientSecretFile string
|
||||
Scope string
|
||||
Prompt string
|
||||
|
||||
// Universal Group authorization data structure
|
||||
// any provider can set to consume
|
||||
AllowedGroups map[string]struct{}
|
||||
}
|
||||
|
||||
// Data returns the ProviderData
|
||||
@ -45,6 +49,15 @@ func (p *ProviderData) GetClientSecret() (clientSecret string, err error) {
|
||||
return string(fileClientSecret), nil
|
||||
}
|
||||
|
||||
// SetAllowedGroups organizes a group list into the AllowedGroups map
|
||||
// to be consumed by Authorize implementations
|
||||
func (p *ProviderData) SetAllowedGroups(groups []string) {
|
||||
p.AllowedGroups = map[string]struct{}{}
|
||||
for _, group := range groups {
|
||||
p.AllowedGroups[group] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
type providerDefaults struct {
|
||||
name string
|
||||
loginURL *url.URL
|
||||
|
@ -92,12 +92,6 @@ func (p *ProviderData) GetEmailAddress(_ context.Context, _ *sessions.SessionSta
|
||||
return "", ErrNotImplemented
|
||||
}
|
||||
|
||||
// ValidateGroup validates that the provided email exists in the configured provider
|
||||
// email group(s).
|
||||
func (p *ProviderData) ValidateGroup(_ string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// EnrichSessionState is called after Redeem to allow providers to enrich session fields
|
||||
// such as User, Email, Groups with provider specific API calls.
|
||||
func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.SessionState) error {
|
||||
@ -107,7 +101,17 @@ func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.Session
|
||||
// Authorize performs global authorization on an authenticated session.
|
||||
// This is not used for fine-grained per route authorization rules.
|
||||
func (p *ProviderData) Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
return true, nil
|
||||
if len(p.AllowedGroups) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
for _, group := range s.Groups {
|
||||
if _, ok := p.AllowedGroups[group]; ok {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// ValidateSessionState validates the AccessToken
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@ -53,3 +54,53 @@ func TestEnrichSessionState(t *testing.T) {
|
||||
s := &sessions.SessionState{}
|
||||
assert.NoError(t, p.EnrichSessionState(context.Background(), s))
|
||||
}
|
||||
|
||||
func TestProviderDataAuthorize(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
allowedGroups []string
|
||||
groups []string
|
||||
expectedAuthZ bool
|
||||
}{
|
||||
{
|
||||
name: "NoAllowedGroups",
|
||||
allowedGroups: []string{},
|
||||
groups: []string{},
|
||||
expectedAuthZ: true,
|
||||
},
|
||||
{
|
||||
name: "NoAllowedGroupsUserHasGroups",
|
||||
allowedGroups: []string{},
|
||||
groups: []string{"foo", "bar"},
|
||||
expectedAuthZ: true,
|
||||
},
|
||||
{
|
||||
name: "UserInAllowedGroup",
|
||||
allowedGroups: []string{"foo"},
|
||||
groups: []string{"foo", "bar"},
|
||||
expectedAuthZ: true,
|
||||
},
|
||||
{
|
||||
name: "UserNotInAllowedGroup",
|
||||
allowedGroups: []string{"bar"},
|
||||
groups: []string{"baz", "foo"},
|
||||
expectedAuthZ: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
|
||||
session := &sessions.SessionState{
|
||||
Groups: tc.groups,
|
||||
}
|
||||
p := &ProviderData{}
|
||||
p.SetAllowedGroups(tc.allowedGroups)
|
||||
|
||||
authorized, err := p.Authorize(context.Background(), session)
|
||||
g.Expect(err).ToNot(HaveOccurred())
|
||||
g.Expect(authorized).To(Equal(tc.expectedAuthZ))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user