diff --git a/oauthproxy.go b/oauthproxy.go index 30b79dee..383b3bad 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -909,7 +909,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { } // set cookie, or deny - if p.Validator(session.Email) && p.provider.ValidateGroup(session.Email) { + if p.Validator(session.Email) { logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session) err := p.SaveSession(rw, req, session) if err != nil { @@ -991,15 +991,19 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R return nil, ErrNeedsLogin } - invalidEmail := session != nil && session.Email != "" && !p.Validator(session.Email) + invalidEmail := session.Email != "" && !p.Validator(session.Email) invalidGroups := session != nil && !p.validateGroups(session.Groups) + authorized, err := p.provider.Authorize(req.Context(), session) + if err != nil { + logger.Errorf("Error with authorization: %v", err) + } - if invalidEmail || invalidGroups { - logger.Printf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session) + if invalidEmail || invalidGroups || !authorized { + logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session) // Invalid session, clear it err := p.ClearSessionCookie(rw, req) if err != nil { - logger.Printf("Error clearing session cookie: %v", err) + logger.Errorf("Error clearing session cookie: %v", err) } return nil, ErrNeedsLogin } diff --git a/providers/google.go b/providers/google.go index 97d1312e..97ea52d7 100644 --- a/providers/google.go +++ b/providers/google.go @@ -25,10 +25,13 @@ import ( // GoogleProvider represents an Google based Identity Provider type GoogleProvider struct { *ProviderData + RedeemRefreshURL *url.URL - // GroupValidator is a function that determines if the passed email is in - // the configured Google group. - GroupValidator func(string) bool + // GroupValidator is a function that determines if the user in the passed + // session is a member of any of the configured Google groups. + GroupValidator func(*sessions.SessionState, bool) bool + + allowedGroups map[string]struct{} } var _ Provider = (*GoogleProvider)(nil) @@ -86,7 +89,7 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider { ProviderData: p, // Set a default GroupValidator to just always return valid (true), it will // be overwritten if we configured a Google group restriction. - GroupValidator: func(email string) bool { + GroupValidator: func(*sessions.SessionState, bool) bool { return true }, } @@ -118,14 +121,14 @@ func claimsFromIDToken(idToken string) (*claims, error) { } // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { +func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { if code == "" { - err = errors.New("missing code") - return + err := errors.New("missing code") + return nil, err } clientSecret, err := p.GetClientSecret() if err != nil { - return + return nil, err } params := url.Values{} @@ -155,12 +158,12 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( c, err := claimsFromIDToken(jsonResponse.IDToken) if err != nil { - return + return nil, err } created := time.Now() expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second) - s = &sessions.SessionState{ + s := &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, CreatedAt: &created, @@ -169,7 +172,13 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( Email: c.Email, User: c.Subject, } - return + p.GroupValidator(s, true) + + return s, nil +} + +func (p *GoogleProvider) Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) { + return p.GroupValidator(s, false), nil } // SetGroupRestriction configures the GoogleProvider to restrict access to the @@ -178,8 +187,30 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( // account credentials. func (p *GoogleProvider) SetGroupRestriction(groups []string, adminEmail string, credentialsReader io.Reader) { adminService := getAdminService(adminEmail, credentialsReader) - p.GroupValidator = func(email string) bool { - return userInGroup(adminService, groups, email) + for _, group := range groups { + p.allowedGroups[group] = struct{}{} + } + + p.GroupValidator = func(s *sessions.SessionState, sync bool) bool { + if sync { + // Reset our saved Groups in case membership changed + s.Groups = make([]string, 0, len(groups)) + for _, group := range groups { + if userInGroup(adminService, group, s.Email) { + s.Groups = append(s.Groups, group) + } + } + return len(s.Groups) > 0 + } + + // Don't resync with Google, handles when OAuth2-Proxy settings + // alter allowed groups but existing sessions are still valid + for _, group := range s.Groups { + if _, ok := p.allowedGroups[group]; ok { + return true + } + } + return false } } @@ -203,52 +234,41 @@ func getAdminService(adminEmail string, credentialsReader io.Reader) *admin.Serv return adminService } -func userInGroup(service *admin.Service, groups []string, email string) bool { - for _, group := range groups { - // Use the HasMember API to checking for the user's presence in each group or nested subgroups - req := service.Members.HasMember(group, email) +func userInGroup(service *admin.Service, group string, email string) bool { + // Use the HasMember API to checking for the user's presence in each group or nested subgroups + req := service.Members.HasMember(group, email) + r, err := req.Do() + if err == nil { + return r.IsMember + } + + gerr, ok := err.(*googleapi.Error) + switch { + case ok && gerr.Code == 404: + logger.Errorf("error checking membership in group %s: group does not exist", group) + case ok && gerr.Code == 400: + // It is possible for Members.HasMember to return false even if the email is a group member. + // One case that can cause this is if the user email is from a different domain than the group, + // e.g. "member@otherdomain.com" in the group "group@mydomain.com" will result in a 400 error + // from the HasMember API. In that case, attempt to query the member object directly from the group. + req := service.Members.Get(group, email) r, err := req.Do() if err != nil { - gerr, ok := err.(*googleapi.Error) - switch { - case ok && gerr.Code == 404: - logger.Errorf("error checking membership in group %s: group does not exist", group) - case ok && gerr.Code == 400: - // It is possible for Members.HasMember to return false even if the email is a group member. - // One case that can cause this is if the user email is from a different domain than the group, - // e.g. "member@otherdomain.com" in the group "group@mydomain.com" will result in a 400 error - // from the HasMember API. In that case, attempt to query the member object directly from the group. - req := service.Members.Get(group, email) - r, err := req.Do() - - if err != nil { - logger.Errorf("error using get API to check member %s of google group %s: user not in the group", email, group) - continue - } - - // If the non-domain user is found within the group, still verify that they are "ACTIVE". - // Do not count the user as belonging to a group if they have another status ("ARCHIVED", "SUSPENDED", or "UNKNOWN"). - if r.Status == "ACTIVE" { - return true - } - default: - logger.Errorf("error checking group membership: %v", err) - } - continue + logger.Errorf("error using get API to check member %s of google group %s: user not in the group", email, group) + return false } - if r.IsMember { + + // If the non-domain user is found within the group, still verify that they are "ACTIVE". + // Do not count the user as belonging to a group if they have another status ("ARCHIVED", "SUSPENDED", or "UNKNOWN"). + if r.Status == "ACTIVE" { return true } + default: + logger.Errorf("error checking group membership: %v", err) } return false } -// ValidateGroup validates that the provided email exists in the configured Google -// group(s). -func (p *GoogleProvider) ValidateGroup(email string) bool { - return p.GroupValidator(email) -} - // RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshToken to fetch a new ID token if required func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { @@ -262,7 +282,7 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions } // re-check that the user is in the proper google group(s) - if !p.ValidateGroup(s.Email) { + if !p.GroupValidator(s, true) { return false, fmt.Errorf("%s is no longer in the group(s)", s.Email) } diff --git a/providers/google_test.go b/providers/google_test.go index 35fc7f49..5e678f22 100644 --- a/providers/google_test.go +++ b/providers/google_test.go @@ -10,6 +10,7 @@ import ( "net/url" "testing" + "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" . "github.com/onsi/gomega" "github.com/stretchr/testify/assert" admin "google.golang.org/api/admin/directory/v1" @@ -109,21 +110,52 @@ func TestGoogleProviderGetEmailAddress(t *testing.T) { assert.Equal(t, "refresh12345", session.RefreshToken) } -func TestGoogleProviderValidateGroup(t *testing.T) { - p := newGoogleProvider() - p.GroupValidator = func(email string) bool { - return email == "michael.bland@gsa.gov" - } - assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov")) - p.GroupValidator = func(email string) bool { - return email != "michael.bland@gsa.gov" - } - assert.Equal(t, false, p.ValidateGroup("michael.bland@gsa.gov")) -} +func TestGoogleProviderAuthorize(t *testing.T) { + const sessionEmail = "michael.bland@gsa.gov" -func TestGoogleProviderWithoutValidateGroup(t *testing.T) { - p := newGoogleProvider() - assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov")) + testCases := map[string]struct { + session *sessions.SessionState + validatorFunc func(*sessions.SessionState, bool) bool + expectedAuthZ bool + }{ + "Email is authorized with GroupValidator": { + session: &sessions.SessionState{ + Email: sessionEmail, + }, + validatorFunc: func(s *sessions.SessionState, _ bool) bool { + return s.Email == sessionEmail + }, + expectedAuthZ: true, + }, + "Email is denied with GroupValidator": { + session: &sessions.SessionState{ + Email: sessionEmail, + }, + validatorFunc: func(s *sessions.SessionState, _ bool) bool { + return s.Email != sessionEmail + }, + expectedAuthZ: false, + }, + "Default does no authorization checks": { + session: &sessions.SessionState{ + Email: sessionEmail, + }, + validatorFunc: nil, + expectedAuthZ: true, + }, + } + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + g := NewWithT(t) + p := newGoogleProvider() + if tc.validatorFunc != nil { + p.GroupValidator = tc.validatorFunc + } + authorized, err := p.Authorize(context.Background(), tc.session) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(authorized).To(Equal(tc.expectedAuthZ)) + }) + } } // @@ -196,7 +228,7 @@ func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) { } -func TestGoogleProviderUserInGroup(t *testing.T) { +func TestGoogleProvider_userInGroup(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/groups/group@example.com/hasMember/member-in-domain@example.com" { fmt.Fprintln(w, `{"isMember": true}`) @@ -233,18 +265,19 @@ func TestGoogleProviderUserInGroup(t *testing.T) { ctx := context.Background() service, err := admin.NewService(ctx, option.WithHTTPClient(client)) + assert.NoError(t, err) + service.BasePath = ts.URL - assert.Equal(t, nil, err) - result := userInGroup(service, []string{"group@example.com"}, "member-in-domain@example.com") + result := userInGroup(service, "group@example.com", "member-in-domain@example.com") assert.True(t, result) - result = userInGroup(service, []string{"group@example.com"}, "member-out-of-domain@otherexample.com") + result = userInGroup(service, "group@example.com", "member-out-of-domain@otherexample.com") assert.True(t, result) - result = userInGroup(service, []string{"group@example.com"}, "non-member-in-domain@example.com") + result = userInGroup(service, "group@example.com", "non-member-in-domain@example.com") assert.False(t, result) - result = userInGroup(service, []string{"group@example.com"}, "non-member-out-of-domain@otherexample.com") + result = userInGroup(service, "group@example.com", "non-member-out-of-domain@otherexample.com") assert.False(t, result) } diff --git a/providers/provider_default.go b/providers/provider_default.go index 479e52c0..6f3892e5 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -104,6 +104,12 @@ func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.Session return nil } +// Authorize performs global authorization on an authenticated session. +// This is not used for fine-grained per route authorization rules. +func (p *ProviderData) Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) { + return true, nil +} + // ValidateSessionState validates the AccessToken func (p *ProviderData) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { return validateToken(ctx, p, s.AccessToken, nil) diff --git a/providers/providers.go b/providers/providers.go index da707a56..50f4d6b2 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -13,8 +13,8 @@ type Provider interface { // DEPRECATED: Migrate to EnrichSessionState GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error) - ValidateGroup(string) bool EnrichSessionState(ctx context.Context, s *sessions.SessionState) error + Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool GetLoginURL(redirectURI, finalRedirect string) string RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error)