You've already forked oauth2-proxy
mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2026-05-22 10:15:21 +02:00
Streamline Google to use default Authorize
This commit is contained in:
+17
-32
@@ -27,11 +27,14 @@ type GoogleProvider struct {
|
||||
*ProviderData
|
||||
|
||||
RedeemRefreshURL *url.URL
|
||||
|
||||
// GroupValidator is a function that determines if the user in the passed
|
||||
// session is a member of any of the configured Google groups.
|
||||
GroupValidator func(*sessions.SessionState, bool) bool
|
||||
|
||||
allowedGroups map[string]struct{}
|
||||
//
|
||||
// This hits the Google API for each group, so it is called on Redeem &
|
||||
// Refresh. `Authorize` uses the results of this saved in `session.Groups`
|
||||
// Since it is called on every request.
|
||||
GroupValidator func(*sessions.SessionState) bool
|
||||
}
|
||||
|
||||
var _ Provider = (*GoogleProvider)(nil)
|
||||
@@ -89,7 +92,7 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider {
|
||||
ProviderData: p,
|
||||
// Set a default GroupValidator to just always return valid (true), it will
|
||||
// be overwritten if we configured a Google group restriction.
|
||||
GroupValidator: func(*sessions.SessionState, bool) bool {
|
||||
GroupValidator: func(*sessions.SessionState) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
@@ -172,45 +175,27 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
|
||||
Email: c.Email,
|
||||
User: c.Subject,
|
||||
}
|
||||
p.GroupValidator(s, true)
|
||||
p.GroupValidator(s)
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (p *GoogleProvider) Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
return p.GroupValidator(s, false), nil
|
||||
}
|
||||
|
||||
// SetGroupRestriction configures the GoogleProvider to restrict access to the
|
||||
// specified group(s). AdminEmail has to be an administrative email on the domain that is
|
||||
// checked. CredentialsFile is the path to a json file containing a Google service
|
||||
// account credentials.
|
||||
func (p *GoogleProvider) SetGroupRestriction(groups []string, adminEmail string, credentialsReader io.Reader) {
|
||||
adminService := getAdminService(adminEmail, credentialsReader)
|
||||
for _, group := range groups {
|
||||
p.allowedGroups[group] = struct{}{}
|
||||
}
|
||||
|
||||
p.GroupValidator = func(s *sessions.SessionState, sync bool) bool {
|
||||
if sync {
|
||||
// Reset our saved Groups in case membership changed
|
||||
s.Groups = make([]string, 0, len(groups))
|
||||
for _, group := range groups {
|
||||
if userInGroup(adminService, group, s.Email) {
|
||||
s.Groups = append(s.Groups, group)
|
||||
}
|
||||
}
|
||||
return len(s.Groups) > 0
|
||||
}
|
||||
|
||||
// Don't resync with Google, handles when OAuth2-Proxy settings
|
||||
// alter allowed groups but existing sessions are still valid
|
||||
for _, group := range s.Groups {
|
||||
if _, ok := p.allowedGroups[group]; ok {
|
||||
return true
|
||||
p.GroupValidator = func(s *sessions.SessionState) bool {
|
||||
// Reset our saved Groups in case membership changed
|
||||
// This is used by `Authorize` on every request
|
||||
s.Groups = make([]string, 0, len(groups))
|
||||
for _, group := range groups {
|
||||
if userInGroup(adminService, group, s.Email) {
|
||||
s.Groups = append(s.Groups, group)
|
||||
}
|
||||
}
|
||||
return false
|
||||
return len(s.Groups) > 0
|
||||
}
|
||||
}
|
||||
|
||||
@@ -282,7 +267,7 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
|
||||
}
|
||||
|
||||
// re-check that the user is in the proper google group(s)
|
||||
if !p.GroupValidator(s, true) {
|
||||
if !p.GroupValidator(s) {
|
||||
return false, fmt.Errorf("%s is no longer in the group(s)", s.Email)
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/assert"
|
||||
admin "google.golang.org/api/admin/directory/v1"
|
||||
@@ -110,19 +110,19 @@ func TestGoogleProviderGetEmailAddress(t *testing.T) {
|
||||
assert.Equal(t, "refresh12345", session.RefreshToken)
|
||||
}
|
||||
|
||||
func TestGoogleProviderAuthorize(t *testing.T) {
|
||||
func TestGoogleProviderGroupValidator(t *testing.T) {
|
||||
const sessionEmail = "michael.bland@gsa.gov"
|
||||
|
||||
testCases := map[string]struct {
|
||||
session *sessions.SessionState
|
||||
validatorFunc func(*sessions.SessionState, bool) bool
|
||||
validatorFunc func(*sessions.SessionState) bool
|
||||
expectedAuthZ bool
|
||||
}{
|
||||
"Email is authorized with GroupValidator": {
|
||||
session: &sessions.SessionState{
|
||||
Email: sessionEmail,
|
||||
},
|
||||
validatorFunc: func(s *sessions.SessionState, _ bool) bool {
|
||||
validatorFunc: func(s *sessions.SessionState) bool {
|
||||
return s.Email == sessionEmail
|
||||
},
|
||||
expectedAuthZ: true,
|
||||
@@ -131,7 +131,7 @@ func TestGoogleProviderAuthorize(t *testing.T) {
|
||||
session: &sessions.SessionState{
|
||||
Email: sessionEmail,
|
||||
},
|
||||
validatorFunc: func(s *sessions.SessionState, _ bool) bool {
|
||||
validatorFunc: func(s *sessions.SessionState) bool {
|
||||
return s.Email != sessionEmail
|
||||
},
|
||||
expectedAuthZ: false,
|
||||
@@ -151,9 +151,7 @@ func TestGoogleProviderAuthorize(t *testing.T) {
|
||||
if tc.validatorFunc != nil {
|
||||
p.GroupValidator = tc.validatorFunc
|
||||
}
|
||||
authorized, err := p.Authorize(context.Background(), tc.session)
|
||||
g.Expect(err).ToNot(HaveOccurred())
|
||||
g.Expect(authorized).To(Equal(tc.expectedAuthZ))
|
||||
g.Expect(p.GroupValidator(tc.session)).To(Equal(tc.expectedAuthZ))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ func (p *ProviderData) GetClientSecret() (clientSecret string, err error) {
|
||||
// SetAllowedGroups organizes a group list into the AllowedGroups map
|
||||
// to be consumed by Authorize implementations
|
||||
func (p *ProviderData) SetAllowedGroups(groups []string) {
|
||||
p.AllowedGroups = map[string]struct{}{}
|
||||
p.AllowedGroups = make(map[string]struct{}, len(groups))
|
||||
for _, group := range groups {
|
||||
p.AllowedGroups[group] = struct{}{}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user