You've already forked oauth2-proxy
mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-07-13 01:40:48 +02:00
Replace ValidateGroup with Authorize for Provider
This commit is contained in:
@ -909,7 +909,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// set cookie, or deny
|
// set cookie, or deny
|
||||||
if p.Validator(session.Email) && p.provider.ValidateGroup(session.Email) {
|
if p.Validator(session.Email) {
|
||||||
logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session)
|
logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session)
|
||||||
err := p.SaveSession(rw, req, session)
|
err := p.SaveSession(rw, req, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -991,15 +991,19 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R
|
|||||||
return nil, ErrNeedsLogin
|
return nil, ErrNeedsLogin
|
||||||
}
|
}
|
||||||
|
|
||||||
invalidEmail := session != nil && session.Email != "" && !p.Validator(session.Email)
|
invalidEmail := session.Email != "" && !p.Validator(session.Email)
|
||||||
invalidGroups := session != nil && !p.validateGroups(session.Groups)
|
invalidGroups := session != nil && !p.validateGroups(session.Groups)
|
||||||
|
authorized, err := p.provider.Authorize(req.Context(), session)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Error with authorization: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if invalidEmail || invalidGroups {
|
if invalidEmail || invalidGroups || !authorized {
|
||||||
logger.Printf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session)
|
logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session)
|
||||||
// Invalid session, clear it
|
// Invalid session, clear it
|
||||||
err := p.ClearSessionCookie(rw, req)
|
err := p.ClearSessionCookie(rw, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Printf("Error clearing session cookie: %v", err)
|
logger.Errorf("Error clearing session cookie: %v", err)
|
||||||
}
|
}
|
||||||
return nil, ErrNeedsLogin
|
return nil, ErrNeedsLogin
|
||||||
}
|
}
|
||||||
|
@ -25,10 +25,13 @@ import (
|
|||||||
// GoogleProvider represents an Google based Identity Provider
|
// GoogleProvider represents an Google based Identity Provider
|
||||||
type GoogleProvider struct {
|
type GoogleProvider struct {
|
||||||
*ProviderData
|
*ProviderData
|
||||||
|
|
||||||
RedeemRefreshURL *url.URL
|
RedeemRefreshURL *url.URL
|
||||||
// GroupValidator is a function that determines if the passed email is in
|
// GroupValidator is a function that determines if the user in the passed
|
||||||
// the configured Google group.
|
// session is a member of any of the configured Google groups.
|
||||||
GroupValidator func(string) bool
|
GroupValidator func(*sessions.SessionState, bool) bool
|
||||||
|
|
||||||
|
allowedGroups map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Provider = (*GoogleProvider)(nil)
|
var _ Provider = (*GoogleProvider)(nil)
|
||||||
@ -86,7 +89,7 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider {
|
|||||||
ProviderData: p,
|
ProviderData: p,
|
||||||
// Set a default GroupValidator to just always return valid (true), it will
|
// Set a default GroupValidator to just always return valid (true), it will
|
||||||
// be overwritten if we configured a Google group restriction.
|
// be overwritten if we configured a Google group restriction.
|
||||||
GroupValidator: func(email string) bool {
|
GroupValidator: func(*sessions.SessionState, bool) bool {
|
||||||
return true
|
return true
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -118,14 +121,14 @@ func claimsFromIDToken(idToken string) (*claims, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||||
func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
|
func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
|
||||||
if code == "" {
|
if code == "" {
|
||||||
err = errors.New("missing code")
|
err := errors.New("missing code")
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
clientSecret, err := p.GetClientSecret()
|
clientSecret, err := p.GetClientSecret()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
params := url.Values{}
|
params := url.Values{}
|
||||||
@ -155,12 +158,12 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
|
|||||||
|
|
||||||
c, err := claimsFromIDToken(jsonResponse.IDToken)
|
c, err := claimsFromIDToken(jsonResponse.IDToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
created := time.Now()
|
created := time.Now()
|
||||||
expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
|
expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
|
||||||
s = &sessions.SessionState{
|
s := &sessions.SessionState{
|
||||||
AccessToken: jsonResponse.AccessToken,
|
AccessToken: jsonResponse.AccessToken,
|
||||||
IDToken: jsonResponse.IDToken,
|
IDToken: jsonResponse.IDToken,
|
||||||
CreatedAt: &created,
|
CreatedAt: &created,
|
||||||
@ -169,7 +172,13 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
|
|||||||
Email: c.Email,
|
Email: c.Email,
|
||||||
User: c.Subject,
|
User: c.Subject,
|
||||||
}
|
}
|
||||||
return
|
p.GroupValidator(s, true)
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *GoogleProvider) Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||||
|
return p.GroupValidator(s, false), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetGroupRestriction configures the GoogleProvider to restrict access to the
|
// SetGroupRestriction configures the GoogleProvider to restrict access to the
|
||||||
@ -178,8 +187,30 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
|
|||||||
// account credentials.
|
// account credentials.
|
||||||
func (p *GoogleProvider) SetGroupRestriction(groups []string, adminEmail string, credentialsReader io.Reader) {
|
func (p *GoogleProvider) SetGroupRestriction(groups []string, adminEmail string, credentialsReader io.Reader) {
|
||||||
adminService := getAdminService(adminEmail, credentialsReader)
|
adminService := getAdminService(adminEmail, credentialsReader)
|
||||||
p.GroupValidator = func(email string) bool {
|
for _, group := range groups {
|
||||||
return userInGroup(adminService, groups, email)
|
p.allowedGroups[group] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
p.GroupValidator = func(s *sessions.SessionState, sync bool) bool {
|
||||||
|
if sync {
|
||||||
|
// Reset our saved Groups in case membership changed
|
||||||
|
s.Groups = make([]string, 0, len(groups))
|
||||||
|
for _, group := range groups {
|
||||||
|
if userInGroup(adminService, group, s.Email) {
|
||||||
|
s.Groups = append(s.Groups, group)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(s.Groups) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't resync with Google, handles when OAuth2-Proxy settings
|
||||||
|
// alter allowed groups but existing sessions are still valid
|
||||||
|
for _, group := range s.Groups {
|
||||||
|
if _, ok := p.allowedGroups[group]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -203,52 +234,41 @@ func getAdminService(adminEmail string, credentialsReader io.Reader) *admin.Serv
|
|||||||
return adminService
|
return adminService
|
||||||
}
|
}
|
||||||
|
|
||||||
func userInGroup(service *admin.Service, groups []string, email string) bool {
|
func userInGroup(service *admin.Service, group string, email string) bool {
|
||||||
for _, group := range groups {
|
// Use the HasMember API to checking for the user's presence in each group or nested subgroups
|
||||||
// Use the HasMember API to checking for the user's presence in each group or nested subgroups
|
req := service.Members.HasMember(group, email)
|
||||||
req := service.Members.HasMember(group, email)
|
r, err := req.Do()
|
||||||
|
if err == nil {
|
||||||
|
return r.IsMember
|
||||||
|
}
|
||||||
|
|
||||||
|
gerr, ok := err.(*googleapi.Error)
|
||||||
|
switch {
|
||||||
|
case ok && gerr.Code == 404:
|
||||||
|
logger.Errorf("error checking membership in group %s: group does not exist", group)
|
||||||
|
case ok && gerr.Code == 400:
|
||||||
|
// It is possible for Members.HasMember to return false even if the email is a group member.
|
||||||
|
// One case that can cause this is if the user email is from a different domain than the group,
|
||||||
|
// e.g. "member@otherdomain.com" in the group "group@mydomain.com" will result in a 400 error
|
||||||
|
// from the HasMember API. In that case, attempt to query the member object directly from the group.
|
||||||
|
req := service.Members.Get(group, email)
|
||||||
r, err := req.Do()
|
r, err := req.Do()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
gerr, ok := err.(*googleapi.Error)
|
logger.Errorf("error using get API to check member %s of google group %s: user not in the group", email, group)
|
||||||
switch {
|
return false
|
||||||
case ok && gerr.Code == 404:
|
|
||||||
logger.Errorf("error checking membership in group %s: group does not exist", group)
|
|
||||||
case ok && gerr.Code == 400:
|
|
||||||
// It is possible for Members.HasMember to return false even if the email is a group member.
|
|
||||||
// One case that can cause this is if the user email is from a different domain than the group,
|
|
||||||
// e.g. "member@otherdomain.com" in the group "group@mydomain.com" will result in a 400 error
|
|
||||||
// from the HasMember API. In that case, attempt to query the member object directly from the group.
|
|
||||||
req := service.Members.Get(group, email)
|
|
||||||
r, err := req.Do()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("error using get API to check member %s of google group %s: user not in the group", email, group)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the non-domain user is found within the group, still verify that they are "ACTIVE".
|
|
||||||
// Do not count the user as belonging to a group if they have another status ("ARCHIVED", "SUSPENDED", or "UNKNOWN").
|
|
||||||
if r.Status == "ACTIVE" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
logger.Errorf("error checking group membership: %v", err)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
if r.IsMember {
|
|
||||||
|
// If the non-domain user is found within the group, still verify that they are "ACTIVE".
|
||||||
|
// Do not count the user as belonging to a group if they have another status ("ARCHIVED", "SUSPENDED", or "UNKNOWN").
|
||||||
|
if r.Status == "ACTIVE" {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
logger.Errorf("error checking group membership: %v", err)
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateGroup validates that the provided email exists in the configured Google
|
|
||||||
// group(s).
|
|
||||||
func (p *GoogleProvider) ValidateGroup(email string) bool {
|
|
||||||
return p.GroupValidator(email)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||||
// RefreshToken to fetch a new ID token if required
|
// RefreshToken to fetch a new ID token if required
|
||||||
func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||||
@ -262,7 +282,7 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
|
|||||||
}
|
}
|
||||||
|
|
||||||
// re-check that the user is in the proper google group(s)
|
// re-check that the user is in the proper google group(s)
|
||||||
if !p.ValidateGroup(s.Email) {
|
if !p.GroupValidator(s, true) {
|
||||||
return false, fmt.Errorf("%s is no longer in the group(s)", s.Email)
|
return false, fmt.Errorf("%s is no longer in the group(s)", s.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
admin "google.golang.org/api/admin/directory/v1"
|
admin "google.golang.org/api/admin/directory/v1"
|
||||||
@ -109,21 +110,52 @@ func TestGoogleProviderGetEmailAddress(t *testing.T) {
|
|||||||
assert.Equal(t, "refresh12345", session.RefreshToken)
|
assert.Equal(t, "refresh12345", session.RefreshToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGoogleProviderValidateGroup(t *testing.T) {
|
func TestGoogleProviderAuthorize(t *testing.T) {
|
||||||
p := newGoogleProvider()
|
const sessionEmail = "michael.bland@gsa.gov"
|
||||||
p.GroupValidator = func(email string) bool {
|
|
||||||
return email == "michael.bland@gsa.gov"
|
|
||||||
}
|
|
||||||
assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov"))
|
|
||||||
p.GroupValidator = func(email string) bool {
|
|
||||||
return email != "michael.bland@gsa.gov"
|
|
||||||
}
|
|
||||||
assert.Equal(t, false, p.ValidateGroup("michael.bland@gsa.gov"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGoogleProviderWithoutValidateGroup(t *testing.T) {
|
testCases := map[string]struct {
|
||||||
p := newGoogleProvider()
|
session *sessions.SessionState
|
||||||
assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov"))
|
validatorFunc func(*sessions.SessionState, bool) bool
|
||||||
|
expectedAuthZ bool
|
||||||
|
}{
|
||||||
|
"Email is authorized with GroupValidator": {
|
||||||
|
session: &sessions.SessionState{
|
||||||
|
Email: sessionEmail,
|
||||||
|
},
|
||||||
|
validatorFunc: func(s *sessions.SessionState, _ bool) bool {
|
||||||
|
return s.Email == sessionEmail
|
||||||
|
},
|
||||||
|
expectedAuthZ: true,
|
||||||
|
},
|
||||||
|
"Email is denied with GroupValidator": {
|
||||||
|
session: &sessions.SessionState{
|
||||||
|
Email: sessionEmail,
|
||||||
|
},
|
||||||
|
validatorFunc: func(s *sessions.SessionState, _ bool) bool {
|
||||||
|
return s.Email != sessionEmail
|
||||||
|
},
|
||||||
|
expectedAuthZ: false,
|
||||||
|
},
|
||||||
|
"Default does no authorization checks": {
|
||||||
|
session: &sessions.SessionState{
|
||||||
|
Email: sessionEmail,
|
||||||
|
},
|
||||||
|
validatorFunc: nil,
|
||||||
|
expectedAuthZ: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
g := NewWithT(t)
|
||||||
|
p := newGoogleProvider()
|
||||||
|
if tc.validatorFunc != nil {
|
||||||
|
p.GroupValidator = tc.validatorFunc
|
||||||
|
}
|
||||||
|
authorized, err := p.Authorize(context.Background(), tc.session)
|
||||||
|
g.Expect(err).ToNot(HaveOccurred())
|
||||||
|
g.Expect(authorized).To(Equal(tc.expectedAuthZ))
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -196,7 +228,7 @@ func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGoogleProviderUserInGroup(t *testing.T) {
|
func TestGoogleProvider_userInGroup(t *testing.T) {
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.URL.Path == "/groups/group@example.com/hasMember/member-in-domain@example.com" {
|
if r.URL.Path == "/groups/group@example.com/hasMember/member-in-domain@example.com" {
|
||||||
fmt.Fprintln(w, `{"isMember": true}`)
|
fmt.Fprintln(w, `{"isMember": true}`)
|
||||||
@ -233,18 +265,19 @@ func TestGoogleProviderUserInGroup(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
service, err := admin.NewService(ctx, option.WithHTTPClient(client))
|
service, err := admin.NewService(ctx, option.WithHTTPClient(client))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
service.BasePath = ts.URL
|
service.BasePath = ts.URL
|
||||||
assert.Equal(t, nil, err)
|
|
||||||
|
|
||||||
result := userInGroup(service, []string{"group@example.com"}, "member-in-domain@example.com")
|
result := userInGroup(service, "group@example.com", "member-in-domain@example.com")
|
||||||
assert.True(t, result)
|
assert.True(t, result)
|
||||||
|
|
||||||
result = userInGroup(service, []string{"group@example.com"}, "member-out-of-domain@otherexample.com")
|
result = userInGroup(service, "group@example.com", "member-out-of-domain@otherexample.com")
|
||||||
assert.True(t, result)
|
assert.True(t, result)
|
||||||
|
|
||||||
result = userInGroup(service, []string{"group@example.com"}, "non-member-in-domain@example.com")
|
result = userInGroup(service, "group@example.com", "non-member-in-domain@example.com")
|
||||||
assert.False(t, result)
|
assert.False(t, result)
|
||||||
|
|
||||||
result = userInGroup(service, []string{"group@example.com"}, "non-member-out-of-domain@otherexample.com")
|
result = userInGroup(service, "group@example.com", "non-member-out-of-domain@otherexample.com")
|
||||||
assert.False(t, result)
|
assert.False(t, result)
|
||||||
}
|
}
|
||||||
|
@ -104,6 +104,12 @@ func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.Session
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Authorize performs global authorization on an authenticated session.
|
||||||
|
// This is not used for fine-grained per route authorization rules.
|
||||||
|
func (p *ProviderData) Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ValidateSessionState validates the AccessToken
|
// ValidateSessionState validates the AccessToken
|
||||||
func (p *ProviderData) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool {
|
func (p *ProviderData) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool {
|
||||||
return validateToken(ctx, p, s.AccessToken, nil)
|
return validateToken(ctx, p, s.AccessToken, nil)
|
||||||
|
@ -13,8 +13,8 @@ type Provider interface {
|
|||||||
// DEPRECATED: Migrate to EnrichSessionState
|
// DEPRECATED: Migrate to EnrichSessionState
|
||||||
GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error)
|
GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error)
|
||||||
Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error)
|
Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error)
|
||||||
ValidateGroup(string) bool
|
|
||||||
EnrichSessionState(ctx context.Context, s *sessions.SessionState) error
|
EnrichSessionState(ctx context.Context, s *sessions.SessionState) error
|
||||||
|
Authorize(ctx context.Context, s *sessions.SessionState) (bool, error)
|
||||||
ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool
|
ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool
|
||||||
GetLoginURL(redirectURI, finalRedirect string) string
|
GetLoginURL(redirectURI, finalRedirect string) string
|
||||||
RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error)
|
RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error)
|
||||||
|
Reference in New Issue
Block a user