diff --git a/CHANGELOG.md b/CHANGELOG.md index 7c44f180..36fae3e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,11 @@ - [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Existing sessions from v6.0.0 or earlier are no longer valid. They will trigger a reauthentication. - [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) `skip-auth-strip-headers` now applies to all requests, not just those where authentication would be skipped. +- [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) The behavior of the Google provider Groups restriction changes with this + - Either `--google-group` or the new `--allowed-group` will work for Google now (`--google-group` will be used if both are set) + - Group membership lists will be passed to the backend with the `X-Forwarded-Groups` header + - If you change the list of allowed groups, existing sessions that now don't have a valid group will be logged out immediately. + - Previously, group membership was only checked on session creation and refresh. - [#789](https://github.com/oauth2-proxy/oauth2-proxy/pull/789) `--skip-auth-route` is (almost) backwards compatible with `--skip-auth-regex` - We are marking `--skip-auth-regex` as DEPRECATED and will remove it in the next major version. - If your regex contains an `=` and you want it for all methods, you will need to add a leading `=` (this is the area where `--skip-auth-regex` doesn't port perfectly) @@ -18,6 +23,9 @@ ## Breaking Changes - [#911](https://github.com/oauth2-proxy/oauth2-rpoxy/pull/911) Specifying a non-existent provider will cause OAuth2-Proxy to fail on startup instead of defaulting to "google". +- [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) Security changes to Google provider group authorization flow + - If you change the list of allowed groups, existing sessions that now don't have a valid group will be logged out immediately. + - Previously, group membership was only checked on session creation and refresh. - [#722](https://github.com/oauth2-proxy/oauth2-proxy/pull/722) When a Redis session store is configured, OAuth2-Proxy will fail to start up unless connection and health checks to Redis pass - [#800](https://github.com/oauth2-proxy/oauth2-proxy/pull/800) Fix import path for v7. The import path has changed to support the go get installation. - You can now `go get github.com/oauth2-proxy/oauth2-proxy/v7` to get the latest `v7` version of OAuth2 Proxy @@ -40,6 +48,7 @@ - [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Remove v5 legacy sessions support (@NickMeves) - [#904](https://github.com/oauth2-proxy/oauth2-proxy/pull/904) Set `skip-auth-strip-headers` to `true` by default (@NickMeves) - [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) Integrate new header injectors into project (@JoelSpeed) +- [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) Create universal Authorization behavior across providers (@NickMeves) - [#898](https://github.com/oauth2-proxy/oauth2-proxy/pull/898) Migrate documentation to Docusaurus (@JoelSpeed) - [#754](https://github.com/oauth2-proxy/oauth2-proxy/pull/754) Azure token refresh (@codablock) - [#825](https://github.com/oauth2-proxy/oauth2-proxy/pull/825) Fix code coverage reporting on GitHub actions(@JoelSpeed) diff --git a/oauthproxy.go b/oauthproxy.go index 30b79dee..343c6ec9 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -42,6 +42,9 @@ var ( // ErrNeedsLogin means the user should be redirected to the login page ErrNeedsLogin = errors.New("redirect to login page") + // ErrAccessDenied means the user should receive a 401 Unauthorized response + ErrAccessDenied = errors.New("access denied") + // Used to check final redirects are not susceptible to open redirects. // Matches //, /\ and both of these with whitespace in between (eg / / or / \). invalidRedirectRegex = regexp.MustCompile(`[/\\](?:[\s\v]*|\.{1,2})[/\\]`) @@ -105,7 +108,6 @@ type OAuthProxy struct { trustedIPs *ip.NetSet Banner string Footer string - AllowedGroups []string sessionChain alice.Chain headersChain alice.Chain @@ -219,7 +221,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr Banner: opts.Banner, Footer: opts.Footer, SignInMessage: buildSignInMessage(opts), - AllowedGroups: opts.AllowedGroups, basicAuthValidator: basicAuthValidator, displayHtpasswdForm: basicAuthValidator != nil && opts.DisplayHtpasswdForm, @@ -396,7 +397,7 @@ func (p *OAuthProxy) GetRedirectURI(host string) string { func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) { if code == "" { - return nil, errors.New("missing code") + return nil, providers.ErrMissingCode } redirectURI := p.GetRedirectURI(host) s, err := p.provider.Redeem(ctx, redirectURI, code) @@ -909,11 +910,15 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { } // set cookie, or deny - if p.Validator(session.Email) && p.provider.ValidateGroup(session.Email) { + authorized, err := p.provider.Authorize(req.Context(), session) + if err != nil { + logger.Errorf("Error with authorization: %v", err) + } + if p.Validator(session.Email) && authorized { logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session) err := p.SaveSession(rw, req, session) if err != nil { - logger.Printf("Error saving session state for %s: %v", remoteAddr, err) + logger.Errorf("Error saving session state for %s: %v", remoteAddr, err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } @@ -967,6 +972,9 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { p.SignInPage(rw, req, http.StatusForbidden) } + case ErrAccessDenied: + p.ErrorPage(rw, http.StatusUnauthorized, "Permission Denied", "Unauthorized") + default: // unknown error logger.Errorf("Unexpected internal error: %v", err) @@ -977,7 +985,9 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { } // getAuthenticatedSession checks whether a user is authenticated and returns a session object and nil error if so -// Returns nil, ErrNeedsLogin if user needs to login. +// Returns: +// - `nil, ErrNeedsLogin` if user needs to login. +// - `nil, ErrAccessDenied` if the authenticated user is not authorized // Set-Cookie headers may be set on the response as a side-effect of calling this method. func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) { var session *sessionsapi.SessionState @@ -991,17 +1001,20 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R return nil, ErrNeedsLogin } - invalidEmail := session != nil && session.Email != "" && !p.Validator(session.Email) - invalidGroups := session != nil && !p.validateGroups(session.Groups) + invalidEmail := session.Email != "" && !p.Validator(session.Email) + authorized, err := p.provider.Authorize(req.Context(), session) + if err != nil { + logger.Errorf("Error with authorization: %v", err) + } - if invalidEmail || invalidGroups { - logger.Printf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session) + if invalidEmail || !authorized { + logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authorization via session: removing session %s", session) // Invalid session, clear it err := p.ClearSessionCookie(rw, req) if err != nil { - logger.Printf("Error clearing session cookie: %v", err) + logger.Errorf("Error clearing session cookie: %v", err) } - return nil, ErrNeedsLogin + return nil, ErrAccessDenied } return session, nil @@ -1033,23 +1046,3 @@ func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) { rw.Header().Set("Content-Type", applicationJSON) rw.WriteHeader(code) } - -func (p *OAuthProxy) validateGroups(groups []string) bool { - if len(p.AllowedGroups) == 0 { - return true - } - - allowedGroups := map[string]struct{}{} - - for _, group := range p.AllowedGroups { - allowedGroups[group] = struct{}{} - } - - for _, group := range groups { - if _, ok := allowedGroups[group]; ok { - return true - } - } - - return false -} diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 1736b39f..a2733f6d 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -976,8 +976,10 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi return nil, err } pcTest.proxy.provider = &TestProvider{ - ValidToken: opts.providerValidateCookieResponse, + ProviderData: &providers.ProviderData{}, + ValidToken: opts.providerValidateCookieResponse, } + pcTest.proxy.provider.(*TestProvider).SetAllowedGroups(pcTest.opts.AllowedGroups) // Now, zero-out proxy.CookieRefresh for the cases that don't involve // access_token validation. @@ -1284,7 +1286,8 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { t.Fatal(err) } pcTest.proxy.provider = &TestProvider{ - ValidToken: true, + ProviderData: &providers.ProviderData{}, + ValidToken: true, } pcTest.validateUser = true @@ -1376,7 +1379,8 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) { t.Fatal(err) } pcTest.proxy.provider = &TestProvider{ - ValidToken: true, + ProviderData: &providers.ProviderData{}, + ValidToken: true, } pcTest.validateUser = true @@ -1455,7 +1459,8 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) { t.Fatal(err) } pcTest.proxy.provider = &TestProvider{ - ValidToken: true, + ProviderData: &providers.ProviderData{}, + ValidToken: true, } pcTest.validateUser = true diff --git a/pkg/validation/options.go b/pkg/validation/options.go index fffd94ae..839c2035 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -233,6 +233,8 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs) p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs) + p.SetAllowedGroups(o.AllowedGroups) + provider := providers.New(o.ProviderType, p) if provider == nil { msgs = append(msgs, fmt.Sprintf("invalid setting: provider '%s' is not available", o.ProviderType)) @@ -255,7 +257,13 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { if err != nil { msgs = append(msgs, "invalid Google credentials file: "+o.GoogleServiceAccountJSON) } else { - p.SetGroupRestriction(o.GoogleGroups, o.GoogleAdminEmail, file) + groups := o.AllowedGroups + // Backwards compatibility with `--google-group` option + if len(o.GoogleGroups) > 0 { + groups = o.GoogleGroups + p.SetAllowedGroups(groups) + } + p.SetGroupRestriction(groups, o.GoogleAdminEmail, file) } } case *providers.BitbucketProvider: diff --git a/providers/azure.go b/providers/azure.go index d65b11f4..e72f1068 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -108,14 +108,13 @@ func overrideTenantURL(current, defaultURL *url.URL, tenant, path string) { } // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { +func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { if code == "" { - err = errors.New("missing code") - return + return nil, ErrMissingCode } clientSecret, err := p.GetClientSecret() if err != nil { - return + return nil, err } params := url.Values{} @@ -149,15 +148,14 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s created := time.Now() expires := time.Unix(jsonResponse.ExpiresOn, 0) - s = &sessions.SessionState{ + + return &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, CreatedAt: &created, ExpiresOn: &expires, RefreshToken: jsonResponse.RefreshToken, - } - return - + }, nil } // RefreshSessionIfNeeded checks if the session has expired and uses the diff --git a/providers/google.go b/providers/google.go index 97d1312e..36e84885 100644 --- a/providers/google.go +++ b/providers/google.go @@ -25,10 +25,16 @@ import ( // GoogleProvider represents an Google based Identity Provider type GoogleProvider struct { *ProviderData + RedeemRefreshURL *url.URL - // GroupValidator is a function that determines if the passed email is in - // the configured Google group. - GroupValidator func(string) bool + + // groupValidator is a function that determines if the user in the passed + // session is a member of any of the configured Google groups. + // + // This hits the Google API for each group, so it is called on Redeem & + // Refresh. `Authorize` uses the results of this saved in `session.Groups` + // Since it is called on every request. + groupValidator func(*sessions.SessionState) bool } var _ Provider = (*GoogleProvider)(nil) @@ -84,9 +90,9 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider { }) return &GoogleProvider{ ProviderData: p, - // Set a default GroupValidator to just always return valid (true), it will + // Set a default groupValidator to just always return valid (true), it will // be overwritten if we configured a Google group restriction. - GroupValidator: func(email string) bool { + groupValidator: func(*sessions.SessionState) bool { return true }, } @@ -118,14 +124,13 @@ func claimsFromIDToken(idToken string) (*claims, error) { } // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { +func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { if code == "" { - err = errors.New("missing code") - return + return nil, ErrMissingCode } clientSecret, err := p.GetClientSecret() if err != nil { - return + return nil, err } params := url.Values{} @@ -155,12 +160,13 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( c, err := claimsFromIDToken(jsonResponse.IDToken) if err != nil { - return + return nil, err } created := time.Now() expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second) - s = &sessions.SessionState{ + + return &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, CreatedAt: &created, @@ -168,18 +174,40 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( RefreshToken: jsonResponse.RefreshToken, Email: c.Email, User: c.Subject, - } - return + }, nil +} + +// EnrichSessionState checks the listed Google Groups configured and adds any +// that the user is a member of to session.Groups. +func (p *GoogleProvider) EnrichSessionState(ctx context.Context, s *sessions.SessionState) error { + // TODO (@NickMeves) - Move to pure EnrichSessionState logic and stop + // reusing legacy `groupValidator`. + // + // This is called here to get the validator to do the `session.Groups` + // populating logic. + p.groupValidator(s) + + return nil } // SetGroupRestriction configures the GoogleProvider to restrict access to the // specified group(s). AdminEmail has to be an administrative email on the domain that is // checked. CredentialsFile is the path to a json file containing a Google service // account credentials. +// +// TODO (@NickMeves) - Unit Test this OR refactor away from groupValidator func func (p *GoogleProvider) SetGroupRestriction(groups []string, adminEmail string, credentialsReader io.Reader) { adminService := getAdminService(adminEmail, credentialsReader) - p.GroupValidator = func(email string) bool { - return userInGroup(adminService, groups, email) + p.groupValidator = func(s *sessions.SessionState) bool { + // Reset our saved Groups in case membership changed + // This is used by `Authorize` on every request + s.Groups = make([]string, 0, len(groups)) + for _, group := range groups { + if userInGroup(adminService, group, s.Email) { + s.Groups = append(s.Groups, group) + } + } + return len(s.Groups) > 0 } } @@ -203,52 +231,41 @@ func getAdminService(adminEmail string, credentialsReader io.Reader) *admin.Serv return adminService } -func userInGroup(service *admin.Service, groups []string, email string) bool { - for _, group := range groups { - // Use the HasMember API to checking for the user's presence in each group or nested subgroups - req := service.Members.HasMember(group, email) +func userInGroup(service *admin.Service, group string, email string) bool { + // Use the HasMember API to checking for the user's presence in each group or nested subgroups + req := service.Members.HasMember(group, email) + r, err := req.Do() + if err == nil { + return r.IsMember + } + + gerr, ok := err.(*googleapi.Error) + switch { + case ok && gerr.Code == 404: + logger.Errorf("error checking membership in group %s: group does not exist", group) + case ok && gerr.Code == 400: + // It is possible for Members.HasMember to return false even if the email is a group member. + // One case that can cause this is if the user email is from a different domain than the group, + // e.g. "member@otherdomain.com" in the group "group@mydomain.com" will result in a 400 error + // from the HasMember API. In that case, attempt to query the member object directly from the group. + req := service.Members.Get(group, email) r, err := req.Do() if err != nil { - gerr, ok := err.(*googleapi.Error) - switch { - case ok && gerr.Code == 404: - logger.Errorf("error checking membership in group %s: group does not exist", group) - case ok && gerr.Code == 400: - // It is possible for Members.HasMember to return false even if the email is a group member. - // One case that can cause this is if the user email is from a different domain than the group, - // e.g. "member@otherdomain.com" in the group "group@mydomain.com" will result in a 400 error - // from the HasMember API. In that case, attempt to query the member object directly from the group. - req := service.Members.Get(group, email) - r, err := req.Do() - - if err != nil { - logger.Errorf("error using get API to check member %s of google group %s: user not in the group", email, group) - continue - } - - // If the non-domain user is found within the group, still verify that they are "ACTIVE". - // Do not count the user as belonging to a group if they have another status ("ARCHIVED", "SUSPENDED", or "UNKNOWN"). - if r.Status == "ACTIVE" { - return true - } - default: - logger.Errorf("error checking group membership: %v", err) - } - continue + logger.Errorf("error using get API to check member %s of google group %s: user not in the group", email, group) + return false } - if r.IsMember { + + // If the non-domain user is found within the group, still verify that they are "ACTIVE". + // Do not count the user as belonging to a group if they have another status ("ARCHIVED", "SUSPENDED", or "UNKNOWN"). + if r.Status == "ACTIVE" { return true } + default: + logger.Errorf("error checking group membership: %v", err) } return false } -// ValidateGroup validates that the provided email exists in the configured Google -// group(s). -func (p *GoogleProvider) ValidateGroup(email string) bool { - return p.GroupValidator(email) -} - // RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshToken to fetch a new ID token if required func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { @@ -261,8 +278,11 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions return false, err } + // TODO (@NickMeves) - Align Group authorization needs with other providers' + // behavior in the `RefreshSession` case. + // // re-check that the user is in the proper google group(s) - if !p.ValidateGroup(s.Email) { + if !p.groupValidator(s) { return false, fmt.Errorf("%s is no longer in the group(s)", s.Email) } diff --git a/providers/google_test.go b/providers/google_test.go index 35fc7f49..458439d6 100644 --- a/providers/google_test.go +++ b/providers/google_test.go @@ -10,6 +10,7 @@ import ( "net/url" "testing" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" . "github.com/onsi/gomega" "github.com/stretchr/testify/assert" admin "google.golang.org/api/admin/directory/v1" @@ -109,21 +110,50 @@ func TestGoogleProviderGetEmailAddress(t *testing.T) { assert.Equal(t, "refresh12345", session.RefreshToken) } -func TestGoogleProviderValidateGroup(t *testing.T) { - p := newGoogleProvider() - p.GroupValidator = func(email string) bool { - return email == "michael.bland@gsa.gov" - } - assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov")) - p.GroupValidator = func(email string) bool { - return email != "michael.bland@gsa.gov" - } - assert.Equal(t, false, p.ValidateGroup("michael.bland@gsa.gov")) -} +func TestGoogleProviderGroupValidator(t *testing.T) { + const sessionEmail = "michael.bland@gsa.gov" -func TestGoogleProviderWithoutValidateGroup(t *testing.T) { - p := newGoogleProvider() - assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov")) + testCases := map[string]struct { + session *sessions.SessionState + validatorFunc func(*sessions.SessionState) bool + expectedAuthZ bool + }{ + "Email is authorized with groupValidator": { + session: &sessions.SessionState{ + Email: sessionEmail, + }, + validatorFunc: func(s *sessions.SessionState) bool { + return s.Email == sessionEmail + }, + expectedAuthZ: true, + }, + "Email is denied with groupValidator": { + session: &sessions.SessionState{ + Email: sessionEmail, + }, + validatorFunc: func(s *sessions.SessionState) bool { + return s.Email != sessionEmail + }, + expectedAuthZ: false, + }, + "Default does no authorization checks": { + session: &sessions.SessionState{ + Email: sessionEmail, + }, + validatorFunc: nil, + expectedAuthZ: true, + }, + } + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + g := NewWithT(t) + p := newGoogleProvider() + if tc.validatorFunc != nil { + p.groupValidator = tc.validatorFunc + } + g.Expect(p.groupValidator(tc.session)).To(Equal(tc.expectedAuthZ)) + }) + } } // @@ -196,7 +226,7 @@ func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) { } -func TestGoogleProviderUserInGroup(t *testing.T) { +func TestGoogleProvider_userInGroup(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/groups/group@example.com/hasMember/member-in-domain@example.com" { fmt.Fprintln(w, `{"isMember": true}`) @@ -233,18 +263,19 @@ func TestGoogleProviderUserInGroup(t *testing.T) { ctx := context.Background() service, err := admin.NewService(ctx, option.WithHTTPClient(client)) + assert.NoError(t, err) + service.BasePath = ts.URL - assert.Equal(t, nil, err) - result := userInGroup(service, []string{"group@example.com"}, "member-in-domain@example.com") + result := userInGroup(service, "group@example.com", "member-in-domain@example.com") assert.True(t, result) - result = userInGroup(service, []string{"group@example.com"}, "member-out-of-domain@otherexample.com") + result = userInGroup(service, "group@example.com", "member-out-of-domain@otherexample.com") assert.True(t, result) - result = userInGroup(service, []string{"group@example.com"}, "non-member-in-domain@example.com") + result = userInGroup(service, "group@example.com", "non-member-in-domain@example.com") assert.False(t, result) - result = userInGroup(service, []string{"group@example.com"}, "non-member-out-of-domain@otherexample.com") + result = userInGroup(service, "group@example.com", "non-member-out-of-domain@otherexample.com") assert.False(t, result) } diff --git a/providers/logingov.go b/providers/logingov.go index ff48ccc5..44d1cb46 100644 --- a/providers/logingov.go +++ b/providers/logingov.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "crypto/rsa" - "errors" "fmt" "math/rand" "net/url" @@ -153,10 +152,9 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint } // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { +func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { if code == "" { - err = errors.New("missing code") - return + return nil, ErrMissingCode } claims := &jwt.StandardClaims{ @@ -169,7 +167,7 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), claims) ss, err := token.SignedString(p.JWTKey) if err != nil { - return + return nil, err } params := url.Values{} @@ -199,28 +197,27 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) // check nonce here err = checkNonce(jsonResponse.IDToken, p) if err != nil { - return + return nil, err } // Get the email address var email string email, err = emailFromUserInfo(ctx, jsonResponse.AccessToken, p.ProfileURL.String()) if err != nil { - return + return nil, err } created := time.Now() expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second) // Store the data that we found in the session state - s = &sessions.SessionState{ + return &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, CreatedAt: &created, ExpiresOn: &expires, Email: email, - } - return + }, nil } // GetLoginURL overrides GetLoginURL to add login.gov parameters diff --git a/providers/provider_data.go b/providers/provider_data.go index 5fce04ec..0881a1c6 100644 --- a/providers/provider_data.go +++ b/providers/provider_data.go @@ -26,6 +26,10 @@ type ProviderData struct { ClientSecretFile string Scope string Prompt string + + // Universal Group authorization data structure + // any provider can set to consume + AllowedGroups map[string]struct{} } // Data returns the ProviderData @@ -45,6 +49,15 @@ func (p *ProviderData) GetClientSecret() (clientSecret string, err error) { return string(fileClientSecret), nil } +// SetAllowedGroups organizes a group list into the AllowedGroups map +// to be consumed by Authorize implementations +func (p *ProviderData) SetAllowedGroups(groups []string) { + p.AllowedGroups = make(map[string]struct{}, len(groups)) + for _, group := range groups { + p.AllowedGroups[group] = struct{}{} + } +} + type providerDefaults struct { name string loginURL *url.URL diff --git a/providers/provider_default.go b/providers/provider_default.go index 479e52c0..00b70641 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -19,18 +19,21 @@ var ( // implementation method that doesn't have sensible defaults ErrNotImplemented = errors.New("not implemented") + // ErrMissingCode is returned when a Redeem method is called with an empty + // code + ErrMissingCode = errors.New("missing code") + _ Provider = (*ProviderData)(nil) ) // Redeem provides a default implementation of the OAuth2 token redemption process -func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { +func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { if code == "" { - err = errors.New("missing code") - return + return nil, ErrMissingCode } clientSecret, err := p.GetClientSecret() if err != nil { - return + return nil, err } params := url.Values{} @@ -59,24 +62,21 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s } err = result.UnmarshalInto(&jsonResponse) if err == nil { - s = &sessions.SessionState{ + return &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, - } - return + }, nil } - var v url.Values - v, err = url.ParseQuery(string(result.Body())) + values, err := url.ParseQuery(string(result.Body())) if err != nil { - return + return nil, err } - if a := v.Get("access_token"); a != "" { + if token := values.Get("access_token"); token != "" { created := time.Now() - s = &sessions.SessionState{AccessToken: a, CreatedAt: &created} - } else { - err = fmt.Errorf("no access token found %s", result.Body()) + return &sessions.SessionState{AccessToken: token, CreatedAt: &created}, nil } - return + + return nil, fmt.Errorf("no access token found %s", result.Body()) } // GetLoginURL with typical oauth parameters @@ -92,18 +92,28 @@ func (p *ProviderData) GetEmailAddress(_ context.Context, _ *sessions.SessionSta return "", ErrNotImplemented } -// ValidateGroup validates that the provided email exists in the configured provider -// email group(s). -func (p *ProviderData) ValidateGroup(_ string) bool { - return true -} - // EnrichSessionState is called after Redeem to allow providers to enrich session fields // such as User, Email, Groups with provider specific API calls. func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.SessionState) error { return nil } +// Authorize performs global authorization on an authenticated session. +// This is not used for fine-grained per route authorization rules. +func (p *ProviderData) Authorize(_ context.Context, s *sessions.SessionState) (bool, error) { + if len(p.AllowedGroups) == 0 { + return true, nil + } + + for _, group := range s.Groups { + if _, ok := p.AllowedGroups[group]; ok { + return true, nil + } + } + + return false, nil +} + // ValidateSessionState validates the AccessToken func (p *ProviderData) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { return validateToken(ctx, p, s.AccessToken, nil) diff --git a/providers/provider_default_test.go b/providers/provider_default_test.go index f04fe607..c9e87b33 100644 --- a/providers/provider_default_test.go +++ b/providers/provider_default_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" + . "github.com/onsi/gomega" "github.com/stretchr/testify/assert" ) @@ -53,3 +54,53 @@ func TestEnrichSessionState(t *testing.T) { s := &sessions.SessionState{} assert.NoError(t, p.EnrichSessionState(context.Background(), s)) } + +func TestProviderDataAuthorize(t *testing.T) { + testCases := []struct { + name string + allowedGroups []string + groups []string + expectedAuthZ bool + }{ + { + name: "NoAllowedGroups", + allowedGroups: []string{}, + groups: []string{}, + expectedAuthZ: true, + }, + { + name: "NoAllowedGroupsUserHasGroups", + allowedGroups: []string{}, + groups: []string{"foo", "bar"}, + expectedAuthZ: true, + }, + { + name: "UserInAllowedGroup", + allowedGroups: []string{"foo"}, + groups: []string{"foo", "bar"}, + expectedAuthZ: true, + }, + { + name: "UserNotInAllowedGroup", + allowedGroups: []string{"bar"}, + groups: []string{"baz", "foo"}, + expectedAuthZ: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + session := &sessions.SessionState{ + Groups: tc.groups, + } + p := &ProviderData{} + p.SetAllowedGroups(tc.allowedGroups) + + authorized, err := p.Authorize(context.Background(), session) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(authorized).To(Equal(tc.expectedAuthZ)) + }) + } +} diff --git a/providers/providers.go b/providers/providers.go index da707a56..50f4d6b2 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -13,8 +13,8 @@ type Provider interface { // DEPRECATED: Migrate to EnrichSessionState GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error) - ValidateGroup(string) bool EnrichSessionState(ctx context.Context, s *sessions.SessionState) error + Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool GetLoginURL(redirectURI, finalRedirect string) string RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error)