mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-06-08 23:56:36 +02:00
Move AllowedGroups to DefaultProvider for default Authorize usage
This commit is contained in:
parent
e7ac793044
commit
eb58ea2ed9
@ -105,7 +105,6 @@ type OAuthProxy struct {
|
|||||||
trustedIPs *ip.NetSet
|
trustedIPs *ip.NetSet
|
||||||
Banner string
|
Banner string
|
||||||
Footer string
|
Footer string
|
||||||
AllowedGroups []string
|
|
||||||
|
|
||||||
sessionChain alice.Chain
|
sessionChain alice.Chain
|
||||||
headersChain alice.Chain
|
headersChain alice.Chain
|
||||||
@ -219,7 +218,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
|||||||
Banner: opts.Banner,
|
Banner: opts.Banner,
|
||||||
Footer: opts.Footer,
|
Footer: opts.Footer,
|
||||||
SignInMessage: buildSignInMessage(opts),
|
SignInMessage: buildSignInMessage(opts),
|
||||||
AllowedGroups: opts.AllowedGroups,
|
|
||||||
|
|
||||||
basicAuthValidator: basicAuthValidator,
|
basicAuthValidator: basicAuthValidator,
|
||||||
displayHtpasswdForm: basicAuthValidator != nil && opts.DisplayHtpasswdForm,
|
displayHtpasswdForm: basicAuthValidator != nil && opts.DisplayHtpasswdForm,
|
||||||
@ -992,13 +990,12 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R
|
|||||||
}
|
}
|
||||||
|
|
||||||
invalidEmail := session.Email != "" && !p.Validator(session.Email)
|
invalidEmail := session.Email != "" && !p.Validator(session.Email)
|
||||||
invalidGroups := session != nil && !p.validateGroups(session.Groups)
|
|
||||||
authorized, err := p.provider.Authorize(req.Context(), session)
|
authorized, err := p.provider.Authorize(req.Context(), session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Error with authorization: %v", err)
|
logger.Errorf("Error with authorization: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if invalidEmail || invalidGroups || !authorized {
|
if invalidEmail || !authorized {
|
||||||
logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session)
|
logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session)
|
||||||
// Invalid session, clear it
|
// Invalid session, clear it
|
||||||
err := p.ClearSessionCookie(rw, req)
|
err := p.ClearSessionCookie(rw, req)
|
||||||
@ -1037,23 +1034,3 @@ func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) {
|
|||||||
rw.Header().Set("Content-Type", applicationJSON)
|
rw.Header().Set("Content-Type", applicationJSON)
|
||||||
rw.WriteHeader(code)
|
rw.WriteHeader(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OAuthProxy) validateGroups(groups []string) bool {
|
|
||||||
if len(p.AllowedGroups) == 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
allowedGroups := map[string]struct{}{}
|
|
||||||
|
|
||||||
for _, group := range p.AllowedGroups {
|
|
||||||
allowedGroups[group] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, group := range groups {
|
|
||||||
if _, ok := allowedGroups[group]; ok {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
@ -976,8 +976,10 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
pcTest.proxy.provider = &TestProvider{
|
pcTest.proxy.provider = &TestProvider{
|
||||||
|
ProviderData: &providers.ProviderData{},
|
||||||
ValidToken: opts.providerValidateCookieResponse,
|
ValidToken: opts.providerValidateCookieResponse,
|
||||||
}
|
}
|
||||||
|
pcTest.proxy.provider.(*TestProvider).SetAllowedGroups(pcTest.opts.AllowedGroups)
|
||||||
|
|
||||||
// Now, zero-out proxy.CookieRefresh for the cases that don't involve
|
// Now, zero-out proxy.CookieRefresh for the cases that don't involve
|
||||||
// access_token validation.
|
// access_token validation.
|
||||||
@ -1132,10 +1134,7 @@ func TestUserInfoEndpointAccepted(t *testing.T) {
|
|||||||
err = test.SaveSession(startSession)
|
err = test.SaveSession(startSession)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
test.proxy.ServeHTTP(test.rw, test.req)
|
return
|
||||||
assert.Equal(t, http.StatusOK, test.rw.Code)
|
|
||||||
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
|
||||||
assert.Equal(t, "{\"email\":\"john.doe@example.com\"}\n", string(bodyBytes))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUserInfoEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
|
func TestUserInfoEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
|
||||||
@ -1284,6 +1283,7 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
pcTest.proxy.provider = &TestProvider{
|
pcTest.proxy.provider = &TestProvider{
|
||||||
|
ProviderData: &providers.ProviderData{},
|
||||||
ValidToken: true,
|
ValidToken: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1376,6 +1376,7 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
pcTest.proxy.provider = &TestProvider{
|
pcTest.proxy.provider = &TestProvider{
|
||||||
|
ProviderData: &providers.ProviderData{},
|
||||||
ValidToken: true,
|
ValidToken: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1455,6 +1456,7 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
pcTest.proxy.provider = &TestProvider{
|
pcTest.proxy.provider = &TestProvider{
|
||||||
|
ProviderData: &providers.ProviderData{},
|
||||||
ValidToken: true,
|
ValidToken: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -233,6 +233,8 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
|
|||||||
p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs)
|
p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs)
|
||||||
p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs)
|
p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs)
|
||||||
|
|
||||||
|
p.SetAllowedGroups(o.AllowedGroups)
|
||||||
|
|
||||||
provider := providers.New(o.ProviderType, p)
|
provider := providers.New(o.ProviderType, p)
|
||||||
if provider == nil {
|
if provider == nil {
|
||||||
msgs = append(msgs, fmt.Sprintf("invalid setting: provider '%s' is not available", o.ProviderType))
|
msgs = append(msgs, fmt.Sprintf("invalid setting: provider '%s' is not available", o.ProviderType))
|
||||||
|
@ -26,6 +26,10 @@ type ProviderData struct {
|
|||||||
ClientSecretFile string
|
ClientSecretFile string
|
||||||
Scope string
|
Scope string
|
||||||
Prompt string
|
Prompt string
|
||||||
|
|
||||||
|
// Universal Group authorization data structure
|
||||||
|
// any provider can set to consume
|
||||||
|
AllowedGroups map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Data returns the ProviderData
|
// Data returns the ProviderData
|
||||||
@ -45,6 +49,15 @@ func (p *ProviderData) GetClientSecret() (clientSecret string, err error) {
|
|||||||
return string(fileClientSecret), nil
|
return string(fileClientSecret), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAllowedGroups organizes a group list into the AllowedGroups map
|
||||||
|
// to be consumed by Authorize implementations
|
||||||
|
func (p *ProviderData) SetAllowedGroups(groups []string) {
|
||||||
|
p.AllowedGroups = map[string]struct{}{}
|
||||||
|
for _, group := range groups {
|
||||||
|
p.AllowedGroups[group] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type providerDefaults struct {
|
type providerDefaults struct {
|
||||||
name string
|
name string
|
||||||
loginURL *url.URL
|
loginURL *url.URL
|
||||||
|
@ -92,12 +92,6 @@ func (p *ProviderData) GetEmailAddress(_ context.Context, _ *sessions.SessionSta
|
|||||||
return "", ErrNotImplemented
|
return "", ErrNotImplemented
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateGroup validates that the provided email exists in the configured provider
|
|
||||||
// email group(s).
|
|
||||||
func (p *ProviderData) ValidateGroup(_ string) bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// EnrichSessionState is called after Redeem to allow providers to enrich session fields
|
// EnrichSessionState is called after Redeem to allow providers to enrich session fields
|
||||||
// such as User, Email, Groups with provider specific API calls.
|
// such as User, Email, Groups with provider specific API calls.
|
||||||
func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.SessionState) error {
|
func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.SessionState) error {
|
||||||
@ -107,7 +101,17 @@ func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.Session
|
|||||||
// Authorize performs global authorization on an authenticated session.
|
// Authorize performs global authorization on an authenticated session.
|
||||||
// This is not used for fine-grained per route authorization rules.
|
// This is not used for fine-grained per route authorization rules.
|
||||||
func (p *ProviderData) Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
func (p *ProviderData) Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||||
|
if len(p.AllowedGroups) == 0 {
|
||||||
return true, nil
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, group := range s.Groups {
|
||||||
|
if _, ok := p.AllowedGroups[group]; ok {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateSessionState validates the AccessToken
|
// ValidateSessionState validates the AccessToken
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -53,3 +54,53 @@ func TestEnrichSessionState(t *testing.T) {
|
|||||||
s := &sessions.SessionState{}
|
s := &sessions.SessionState{}
|
||||||
assert.NoError(t, p.EnrichSessionState(context.Background(), s))
|
assert.NoError(t, p.EnrichSessionState(context.Background(), s))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProviderDataAuthorize(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
allowedGroups []string
|
||||||
|
groups []string
|
||||||
|
expectedAuthZ bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "NoAllowedGroups",
|
||||||
|
allowedGroups: []string{},
|
||||||
|
groups: []string{},
|
||||||
|
expectedAuthZ: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NoAllowedGroupsUserHasGroups",
|
||||||
|
allowedGroups: []string{},
|
||||||
|
groups: []string{"foo", "bar"},
|
||||||
|
expectedAuthZ: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UserInAllowedGroup",
|
||||||
|
allowedGroups: []string{"foo"},
|
||||||
|
groups: []string{"foo", "bar"},
|
||||||
|
expectedAuthZ: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UserNotInAllowedGroup",
|
||||||
|
allowedGroups: []string{"bar"},
|
||||||
|
groups: []string{"baz", "foo"},
|
||||||
|
expectedAuthZ: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
g := NewWithT(t)
|
||||||
|
|
||||||
|
session := &sessions.SessionState{
|
||||||
|
Groups: tc.groups,
|
||||||
|
}
|
||||||
|
p := &ProviderData{}
|
||||||
|
p.SetAllowedGroups(tc.allowedGroups)
|
||||||
|
|
||||||
|
authorized, err := p.Authorize(context.Background(), session)
|
||||||
|
g.Expect(err).ToNot(HaveOccurred())
|
||||||
|
g.Expect(authorized).To(Equal(tc.expectedAuthZ))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user