1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-03-21 21:47:11 +02:00

Merge pull request #797 from grnhse/refactor-provider-authz

Centralize Provider authorization interface method
This commit is contained in:
Joel Speed 2020-11-12 19:38:55 +00:00 committed by GitHub
commit c377466411
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 285 additions and 150 deletions

View File

@ -6,6 +6,11 @@
- [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Existing sessions from v6.0.0 or earlier are no longer valid. They will trigger a reauthentication. - [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Existing sessions from v6.0.0 or earlier are no longer valid. They will trigger a reauthentication.
- [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) `skip-auth-strip-headers` now applies to all requests, not just those where authentication would be skipped. - [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) `skip-auth-strip-headers` now applies to all requests, not just those where authentication would be skipped.
- [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) The behavior of the Google provider Groups restriction changes with this
- Either `--google-group` or the new `--allowed-group` will work for Google now (`--google-group` will be used if both are set)
- Group membership lists will be passed to the backend with the `X-Forwarded-Groups` header
- If you change the list of allowed groups, existing sessions that now don't have a valid group will be logged out immediately.
- Previously, group membership was only checked on session creation and refresh.
- [#789](https://github.com/oauth2-proxy/oauth2-proxy/pull/789) `--skip-auth-route` is (almost) backwards compatible with `--skip-auth-regex` - [#789](https://github.com/oauth2-proxy/oauth2-proxy/pull/789) `--skip-auth-route` is (almost) backwards compatible with `--skip-auth-regex`
- We are marking `--skip-auth-regex` as DEPRECATED and will remove it in the next major version. - We are marking `--skip-auth-regex` as DEPRECATED and will remove it in the next major version.
- If your regex contains an `=` and you want it for all methods, you will need to add a leading `=` (this is the area where `--skip-auth-regex` doesn't port perfectly) - If your regex contains an `=` and you want it for all methods, you will need to add a leading `=` (this is the area where `--skip-auth-regex` doesn't port perfectly)
@ -18,6 +23,9 @@
## Breaking Changes ## Breaking Changes
- [#911](https://github.com/oauth2-proxy/oauth2-rpoxy/pull/911) Specifying a non-existent provider will cause OAuth2-Proxy to fail on startup instead of defaulting to "google". - [#911](https://github.com/oauth2-proxy/oauth2-rpoxy/pull/911) Specifying a non-existent provider will cause OAuth2-Proxy to fail on startup instead of defaulting to "google".
- [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) Security changes to Google provider group authorization flow
- If you change the list of allowed groups, existing sessions that now don't have a valid group will be logged out immediately.
- Previously, group membership was only checked on session creation and refresh.
- [#722](https://github.com/oauth2-proxy/oauth2-proxy/pull/722) When a Redis session store is configured, OAuth2-Proxy will fail to start up unless connection and health checks to Redis pass - [#722](https://github.com/oauth2-proxy/oauth2-proxy/pull/722) When a Redis session store is configured, OAuth2-Proxy will fail to start up unless connection and health checks to Redis pass
- [#800](https://github.com/oauth2-proxy/oauth2-proxy/pull/800) Fix import path for v7. The import path has changed to support the go get installation. - [#800](https://github.com/oauth2-proxy/oauth2-proxy/pull/800) Fix import path for v7. The import path has changed to support the go get installation.
- You can now `go get github.com/oauth2-proxy/oauth2-proxy/v7` to get the latest `v7` version of OAuth2 Proxy - You can now `go get github.com/oauth2-proxy/oauth2-proxy/v7` to get the latest `v7` version of OAuth2 Proxy
@ -40,6 +48,7 @@
- [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Remove v5 legacy sessions support (@NickMeves) - [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Remove v5 legacy sessions support (@NickMeves)
- [#904](https://github.com/oauth2-proxy/oauth2-proxy/pull/904) Set `skip-auth-strip-headers` to `true` by default (@NickMeves) - [#904](https://github.com/oauth2-proxy/oauth2-proxy/pull/904) Set `skip-auth-strip-headers` to `true` by default (@NickMeves)
- [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) Integrate new header injectors into project (@JoelSpeed) - [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) Integrate new header injectors into project (@JoelSpeed)
- [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) Create universal Authorization behavior across providers (@NickMeves)
- [#898](https://github.com/oauth2-proxy/oauth2-proxy/pull/898) Migrate documentation to Docusaurus (@JoelSpeed) - [#898](https://github.com/oauth2-proxy/oauth2-proxy/pull/898) Migrate documentation to Docusaurus (@JoelSpeed)
- [#754](https://github.com/oauth2-proxy/oauth2-proxy/pull/754) Azure token refresh (@codablock) - [#754](https://github.com/oauth2-proxy/oauth2-proxy/pull/754) Azure token refresh (@codablock)
- [#825](https://github.com/oauth2-proxy/oauth2-proxy/pull/825) Fix code coverage reporting on GitHub actions(@JoelSpeed) - [#825](https://github.com/oauth2-proxy/oauth2-proxy/pull/825) Fix code coverage reporting on GitHub actions(@JoelSpeed)

View File

@ -42,6 +42,9 @@ var (
// ErrNeedsLogin means the user should be redirected to the login page // ErrNeedsLogin means the user should be redirected to the login page
ErrNeedsLogin = errors.New("redirect to login page") ErrNeedsLogin = errors.New("redirect to login page")
// ErrAccessDenied means the user should receive a 401 Unauthorized response
ErrAccessDenied = errors.New("access denied")
// Used to check final redirects are not susceptible to open redirects. // Used to check final redirects are not susceptible to open redirects.
// Matches //, /\ and both of these with whitespace in between (eg / / or / \). // Matches //, /\ and both of these with whitespace in between (eg / / or / \).
invalidRedirectRegex = regexp.MustCompile(`[/\\](?:[\s\v]*|\.{1,2})[/\\]`) invalidRedirectRegex = regexp.MustCompile(`[/\\](?:[\s\v]*|\.{1,2})[/\\]`)
@ -105,7 +108,6 @@ type OAuthProxy struct {
trustedIPs *ip.NetSet trustedIPs *ip.NetSet
Banner string Banner string
Footer string Footer string
AllowedGroups []string
sessionChain alice.Chain sessionChain alice.Chain
headersChain alice.Chain headersChain alice.Chain
@ -219,7 +221,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
Banner: opts.Banner, Banner: opts.Banner,
Footer: opts.Footer, Footer: opts.Footer,
SignInMessage: buildSignInMessage(opts), SignInMessage: buildSignInMessage(opts),
AllowedGroups: opts.AllowedGroups,
basicAuthValidator: basicAuthValidator, basicAuthValidator: basicAuthValidator,
displayHtpasswdForm: basicAuthValidator != nil && opts.DisplayHtpasswdForm, displayHtpasswdForm: basicAuthValidator != nil && opts.DisplayHtpasswdForm,
@ -396,7 +397,7 @@ func (p *OAuthProxy) GetRedirectURI(host string) string {
func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) { func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) {
if code == "" { if code == "" {
return nil, errors.New("missing code") return nil, providers.ErrMissingCode
} }
redirectURI := p.GetRedirectURI(host) redirectURI := p.GetRedirectURI(host)
s, err := p.provider.Redeem(ctx, redirectURI, code) s, err := p.provider.Redeem(ctx, redirectURI, code)
@ -909,11 +910,15 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
} }
// set cookie, or deny // set cookie, or deny
if p.Validator(session.Email) && p.provider.ValidateGroup(session.Email) { authorized, err := p.provider.Authorize(req.Context(), session)
if err != nil {
logger.Errorf("Error with authorization: %v", err)
}
if p.Validator(session.Email) && authorized {
logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session) logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session)
err := p.SaveSession(rw, req, session) err := p.SaveSession(rw, req, session)
if err != nil { if err != nil {
logger.Printf("Error saving session state for %s: %v", remoteAddr, err) logger.Errorf("Error saving session state for %s: %v", remoteAddr, err)
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
return return
} }
@ -967,6 +972,9 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
p.SignInPage(rw, req, http.StatusForbidden) p.SignInPage(rw, req, http.StatusForbidden)
} }
case ErrAccessDenied:
p.ErrorPage(rw, http.StatusUnauthorized, "Permission Denied", "Unauthorized")
default: default:
// unknown error // unknown error
logger.Errorf("Unexpected internal error: %v", err) logger.Errorf("Unexpected internal error: %v", err)
@ -977,7 +985,9 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
} }
// getAuthenticatedSession checks whether a user is authenticated and returns a session object and nil error if so // getAuthenticatedSession checks whether a user is authenticated and returns a session object and nil error if so
// Returns nil, ErrNeedsLogin if user needs to login. // Returns:
// - `nil, ErrNeedsLogin` if user needs to login.
// - `nil, ErrAccessDenied` if the authenticated user is not authorized
// Set-Cookie headers may be set on the response as a side-effect of calling this method. // Set-Cookie headers may be set on the response as a side-effect of calling this method.
func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) { func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) {
var session *sessionsapi.SessionState var session *sessionsapi.SessionState
@ -991,17 +1001,20 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R
return nil, ErrNeedsLogin return nil, ErrNeedsLogin
} }
invalidEmail := session != nil && session.Email != "" && !p.Validator(session.Email) invalidEmail := session.Email != "" && !p.Validator(session.Email)
invalidGroups := session != nil && !p.validateGroups(session.Groups) authorized, err := p.provider.Authorize(req.Context(), session)
if err != nil {
logger.Errorf("Error with authorization: %v", err)
}
if invalidEmail || invalidGroups { if invalidEmail || !authorized {
logger.Printf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session) logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authorization via session: removing session %s", session)
// Invalid session, clear it // Invalid session, clear it
err := p.ClearSessionCookie(rw, req) err := p.ClearSessionCookie(rw, req)
if err != nil { if err != nil {
logger.Printf("Error clearing session cookie: %v", err) logger.Errorf("Error clearing session cookie: %v", err)
} }
return nil, ErrNeedsLogin return nil, ErrAccessDenied
} }
return session, nil return session, nil
@ -1033,23 +1046,3 @@ func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) {
rw.Header().Set("Content-Type", applicationJSON) rw.Header().Set("Content-Type", applicationJSON)
rw.WriteHeader(code) rw.WriteHeader(code)
} }
func (p *OAuthProxy) validateGroups(groups []string) bool {
if len(p.AllowedGroups) == 0 {
return true
}
allowedGroups := map[string]struct{}{}
for _, group := range p.AllowedGroups {
allowedGroups[group] = struct{}{}
}
for _, group := range groups {
if _, ok := allowedGroups[group]; ok {
return true
}
}
return false
}

View File

@ -976,8 +976,10 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi
return nil, err return nil, err
} }
pcTest.proxy.provider = &TestProvider{ pcTest.proxy.provider = &TestProvider{
ValidToken: opts.providerValidateCookieResponse, ProviderData: &providers.ProviderData{},
ValidToken: opts.providerValidateCookieResponse,
} }
pcTest.proxy.provider.(*TestProvider).SetAllowedGroups(pcTest.opts.AllowedGroups)
// Now, zero-out proxy.CookieRefresh for the cases that don't involve // Now, zero-out proxy.CookieRefresh for the cases that don't involve
// access_token validation. // access_token validation.
@ -1284,7 +1286,8 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
pcTest.proxy.provider = &TestProvider{ pcTest.proxy.provider = &TestProvider{
ValidToken: true, ProviderData: &providers.ProviderData{},
ValidToken: true,
} }
pcTest.validateUser = true pcTest.validateUser = true
@ -1376,7 +1379,8 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
pcTest.proxy.provider = &TestProvider{ pcTest.proxy.provider = &TestProvider{
ValidToken: true, ProviderData: &providers.ProviderData{},
ValidToken: true,
} }
pcTest.validateUser = true pcTest.validateUser = true
@ -1455,7 +1459,8 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
pcTest.proxy.provider = &TestProvider{ pcTest.proxy.provider = &TestProvider{
ValidToken: true, ProviderData: &providers.ProviderData{},
ValidToken: true,
} }
pcTest.validateUser = true pcTest.validateUser = true

View File

@ -233,6 +233,8 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs) p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs)
p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs) p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs)
p.SetAllowedGroups(o.AllowedGroups)
provider := providers.New(o.ProviderType, p) provider := providers.New(o.ProviderType, p)
if provider == nil { if provider == nil {
msgs = append(msgs, fmt.Sprintf("invalid setting: provider '%s' is not available", o.ProviderType)) msgs = append(msgs, fmt.Sprintf("invalid setting: provider '%s' is not available", o.ProviderType))
@ -255,7 +257,13 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
if err != nil { if err != nil {
msgs = append(msgs, "invalid Google credentials file: "+o.GoogleServiceAccountJSON) msgs = append(msgs, "invalid Google credentials file: "+o.GoogleServiceAccountJSON)
} else { } else {
p.SetGroupRestriction(o.GoogleGroups, o.GoogleAdminEmail, file) groups := o.AllowedGroups
// Backwards compatibility with `--google-group` option
if len(o.GoogleGroups) > 0 {
groups = o.GoogleGroups
p.SetAllowedGroups(groups)
}
p.SetGroupRestriction(groups, o.GoogleAdminEmail, file)
} }
} }
case *providers.BitbucketProvider: case *providers.BitbucketProvider:

View File

@ -108,14 +108,13 @@ func overrideTenantURL(current, defaultURL *url.URL, tenant, path string) {
} }
// Redeem exchanges the OAuth2 authentication token for an ID token // Redeem exchanges the OAuth2 authentication token for an ID token
func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
if code == "" { if code == "" {
err = errors.New("missing code") return nil, ErrMissingCode
return
} }
clientSecret, err := p.GetClientSecret() clientSecret, err := p.GetClientSecret()
if err != nil { if err != nil {
return return nil, err
} }
params := url.Values{} params := url.Values{}
@ -149,15 +148,14 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s
created := time.Now() created := time.Now()
expires := time.Unix(jsonResponse.ExpiresOn, 0) expires := time.Unix(jsonResponse.ExpiresOn, 0)
s = &sessions.SessionState{
return &sessions.SessionState{
AccessToken: jsonResponse.AccessToken, AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken, IDToken: jsonResponse.IDToken,
CreatedAt: &created, CreatedAt: &created,
ExpiresOn: &expires, ExpiresOn: &expires,
RefreshToken: jsonResponse.RefreshToken, RefreshToken: jsonResponse.RefreshToken,
} }, nil
return
} }
// RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshSessionIfNeeded checks if the session has expired and uses the

View File

@ -25,10 +25,16 @@ import (
// GoogleProvider represents an Google based Identity Provider // GoogleProvider represents an Google based Identity Provider
type GoogleProvider struct { type GoogleProvider struct {
*ProviderData *ProviderData
RedeemRefreshURL *url.URL RedeemRefreshURL *url.URL
// GroupValidator is a function that determines if the passed email is in
// the configured Google group. // groupValidator is a function that determines if the user in the passed
GroupValidator func(string) bool // session is a member of any of the configured Google groups.
//
// This hits the Google API for each group, so it is called on Redeem &
// Refresh. `Authorize` uses the results of this saved in `session.Groups`
// Since it is called on every request.
groupValidator func(*sessions.SessionState) bool
} }
var _ Provider = (*GoogleProvider)(nil) var _ Provider = (*GoogleProvider)(nil)
@ -84,9 +90,9 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider {
}) })
return &GoogleProvider{ return &GoogleProvider{
ProviderData: p, ProviderData: p,
// Set a default GroupValidator to just always return valid (true), it will // Set a default groupValidator to just always return valid (true), it will
// be overwritten if we configured a Google group restriction. // be overwritten if we configured a Google group restriction.
GroupValidator: func(email string) bool { groupValidator: func(*sessions.SessionState) bool {
return true return true
}, },
} }
@ -118,14 +124,13 @@ func claimsFromIDToken(idToken string) (*claims, error) {
} }
// Redeem exchanges the OAuth2 authentication token for an ID token // Redeem exchanges the OAuth2 authentication token for an ID token
func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
if code == "" { if code == "" {
err = errors.New("missing code") return nil, ErrMissingCode
return
} }
clientSecret, err := p.GetClientSecret() clientSecret, err := p.GetClientSecret()
if err != nil { if err != nil {
return return nil, err
} }
params := url.Values{} params := url.Values{}
@ -155,12 +160,13 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
c, err := claimsFromIDToken(jsonResponse.IDToken) c, err := claimsFromIDToken(jsonResponse.IDToken)
if err != nil { if err != nil {
return return nil, err
} }
created := time.Now() created := time.Now()
expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second) expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
s = &sessions.SessionState{
return &sessions.SessionState{
AccessToken: jsonResponse.AccessToken, AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken, IDToken: jsonResponse.IDToken,
CreatedAt: &created, CreatedAt: &created,
@ -168,18 +174,40 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
RefreshToken: jsonResponse.RefreshToken, RefreshToken: jsonResponse.RefreshToken,
Email: c.Email, Email: c.Email,
User: c.Subject, User: c.Subject,
} }, nil
return }
// EnrichSessionState checks the listed Google Groups configured and adds any
// that the user is a member of to session.Groups.
func (p *GoogleProvider) EnrichSessionState(ctx context.Context, s *sessions.SessionState) error {
// TODO (@NickMeves) - Move to pure EnrichSessionState logic and stop
// reusing legacy `groupValidator`.
//
// This is called here to get the validator to do the `session.Groups`
// populating logic.
p.groupValidator(s)
return nil
} }
// SetGroupRestriction configures the GoogleProvider to restrict access to the // SetGroupRestriction configures the GoogleProvider to restrict access to the
// specified group(s). AdminEmail has to be an administrative email on the domain that is // specified group(s). AdminEmail has to be an administrative email on the domain that is
// checked. CredentialsFile is the path to a json file containing a Google service // checked. CredentialsFile is the path to a json file containing a Google service
// account credentials. // account credentials.
//
// TODO (@NickMeves) - Unit Test this OR refactor away from groupValidator func
func (p *GoogleProvider) SetGroupRestriction(groups []string, adminEmail string, credentialsReader io.Reader) { func (p *GoogleProvider) SetGroupRestriction(groups []string, adminEmail string, credentialsReader io.Reader) {
adminService := getAdminService(adminEmail, credentialsReader) adminService := getAdminService(adminEmail, credentialsReader)
p.GroupValidator = func(email string) bool { p.groupValidator = func(s *sessions.SessionState) bool {
return userInGroup(adminService, groups, email) // Reset our saved Groups in case membership changed
// This is used by `Authorize` on every request
s.Groups = make([]string, 0, len(groups))
for _, group := range groups {
if userInGroup(adminService, group, s.Email) {
s.Groups = append(s.Groups, group)
}
}
return len(s.Groups) > 0
} }
} }
@ -203,52 +231,41 @@ func getAdminService(adminEmail string, credentialsReader io.Reader) *admin.Serv
return adminService return adminService
} }
func userInGroup(service *admin.Service, groups []string, email string) bool { func userInGroup(service *admin.Service, group string, email string) bool {
for _, group := range groups { // Use the HasMember API to checking for the user's presence in each group or nested subgroups
// Use the HasMember API to checking for the user's presence in each group or nested subgroups req := service.Members.HasMember(group, email)
req := service.Members.HasMember(group, email) r, err := req.Do()
if err == nil {
return r.IsMember
}
gerr, ok := err.(*googleapi.Error)
switch {
case ok && gerr.Code == 404:
logger.Errorf("error checking membership in group %s: group does not exist", group)
case ok && gerr.Code == 400:
// It is possible for Members.HasMember to return false even if the email is a group member.
// One case that can cause this is if the user email is from a different domain than the group,
// e.g. "member@otherdomain.com" in the group "group@mydomain.com" will result in a 400 error
// from the HasMember API. In that case, attempt to query the member object directly from the group.
req := service.Members.Get(group, email)
r, err := req.Do() r, err := req.Do()
if err != nil { if err != nil {
gerr, ok := err.(*googleapi.Error) logger.Errorf("error using get API to check member %s of google group %s: user not in the group", email, group)
switch { return false
case ok && gerr.Code == 404:
logger.Errorf("error checking membership in group %s: group does not exist", group)
case ok && gerr.Code == 400:
// It is possible for Members.HasMember to return false even if the email is a group member.
// One case that can cause this is if the user email is from a different domain than the group,
// e.g. "member@otherdomain.com" in the group "group@mydomain.com" will result in a 400 error
// from the HasMember API. In that case, attempt to query the member object directly from the group.
req := service.Members.Get(group, email)
r, err := req.Do()
if err != nil {
logger.Errorf("error using get API to check member %s of google group %s: user not in the group", email, group)
continue
}
// If the non-domain user is found within the group, still verify that they are "ACTIVE".
// Do not count the user as belonging to a group if they have another status ("ARCHIVED", "SUSPENDED", or "UNKNOWN").
if r.Status == "ACTIVE" {
return true
}
default:
logger.Errorf("error checking group membership: %v", err)
}
continue
} }
if r.IsMember {
// If the non-domain user is found within the group, still verify that they are "ACTIVE".
// Do not count the user as belonging to a group if they have another status ("ARCHIVED", "SUSPENDED", or "UNKNOWN").
if r.Status == "ACTIVE" {
return true return true
} }
default:
logger.Errorf("error checking group membership: %v", err)
} }
return false return false
} }
// ValidateGroup validates that the provided email exists in the configured Google
// group(s).
func (p *GoogleProvider) ValidateGroup(email string) bool {
return p.GroupValidator(email)
}
// RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required // RefreshToken to fetch a new ID token if required
func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
@ -261,8 +278,11 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
return false, err return false, err
} }
// TODO (@NickMeves) - Align Group authorization needs with other providers'
// behavior in the `RefreshSession` case.
//
// re-check that the user is in the proper google group(s) // re-check that the user is in the proper google group(s)
if !p.ValidateGroup(s.Email) { if !p.groupValidator(s) {
return false, fmt.Errorf("%s is no longer in the group(s)", s.Email) return false, fmt.Errorf("%s is no longer in the group(s)", s.Email)
} }

View File

@ -10,6 +10,7 @@ import (
"net/url" "net/url"
"testing" "testing"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
admin "google.golang.org/api/admin/directory/v1" admin "google.golang.org/api/admin/directory/v1"
@ -109,21 +110,50 @@ func TestGoogleProviderGetEmailAddress(t *testing.T) {
assert.Equal(t, "refresh12345", session.RefreshToken) assert.Equal(t, "refresh12345", session.RefreshToken)
} }
func TestGoogleProviderValidateGroup(t *testing.T) { func TestGoogleProviderGroupValidator(t *testing.T) {
p := newGoogleProvider() const sessionEmail = "michael.bland@gsa.gov"
p.GroupValidator = func(email string) bool {
return email == "michael.bland@gsa.gov"
}
assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov"))
p.GroupValidator = func(email string) bool {
return email != "michael.bland@gsa.gov"
}
assert.Equal(t, false, p.ValidateGroup("michael.bland@gsa.gov"))
}
func TestGoogleProviderWithoutValidateGroup(t *testing.T) { testCases := map[string]struct {
p := newGoogleProvider() session *sessions.SessionState
assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov")) validatorFunc func(*sessions.SessionState) bool
expectedAuthZ bool
}{
"Email is authorized with groupValidator": {
session: &sessions.SessionState{
Email: sessionEmail,
},
validatorFunc: func(s *sessions.SessionState) bool {
return s.Email == sessionEmail
},
expectedAuthZ: true,
},
"Email is denied with groupValidator": {
session: &sessions.SessionState{
Email: sessionEmail,
},
validatorFunc: func(s *sessions.SessionState) bool {
return s.Email != sessionEmail
},
expectedAuthZ: false,
},
"Default does no authorization checks": {
session: &sessions.SessionState{
Email: sessionEmail,
},
validatorFunc: nil,
expectedAuthZ: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
g := NewWithT(t)
p := newGoogleProvider()
if tc.validatorFunc != nil {
p.groupValidator = tc.validatorFunc
}
g.Expect(p.groupValidator(tc.session)).To(Equal(tc.expectedAuthZ))
})
}
} }
// //
@ -196,7 +226,7 @@ func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) {
} }
func TestGoogleProviderUserInGroup(t *testing.T) { func TestGoogleProvider_userInGroup(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/groups/group@example.com/hasMember/member-in-domain@example.com" { if r.URL.Path == "/groups/group@example.com/hasMember/member-in-domain@example.com" {
fmt.Fprintln(w, `{"isMember": true}`) fmt.Fprintln(w, `{"isMember": true}`)
@ -233,18 +263,19 @@ func TestGoogleProviderUserInGroup(t *testing.T) {
ctx := context.Background() ctx := context.Background()
service, err := admin.NewService(ctx, option.WithHTTPClient(client)) service, err := admin.NewService(ctx, option.WithHTTPClient(client))
assert.NoError(t, err)
service.BasePath = ts.URL service.BasePath = ts.URL
assert.Equal(t, nil, err)
result := userInGroup(service, []string{"group@example.com"}, "member-in-domain@example.com") result := userInGroup(service, "group@example.com", "member-in-domain@example.com")
assert.True(t, result) assert.True(t, result)
result = userInGroup(service, []string{"group@example.com"}, "member-out-of-domain@otherexample.com") result = userInGroup(service, "group@example.com", "member-out-of-domain@otherexample.com")
assert.True(t, result) assert.True(t, result)
result = userInGroup(service, []string{"group@example.com"}, "non-member-in-domain@example.com") result = userInGroup(service, "group@example.com", "non-member-in-domain@example.com")
assert.False(t, result) assert.False(t, result)
result = userInGroup(service, []string{"group@example.com"}, "non-member-out-of-domain@otherexample.com") result = userInGroup(service, "group@example.com", "non-member-out-of-domain@otherexample.com")
assert.False(t, result) assert.False(t, result)
} }

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"context" "context"
"crypto/rsa" "crypto/rsa"
"errors"
"fmt" "fmt"
"math/rand" "math/rand"
"net/url" "net/url"
@ -153,10 +152,9 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint
} }
// Redeem exchanges the OAuth2 authentication token for an ID token // Redeem exchanges the OAuth2 authentication token for an ID token
func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
if code == "" { if code == "" {
err = errors.New("missing code") return nil, ErrMissingCode
return
} }
claims := &jwt.StandardClaims{ claims := &jwt.StandardClaims{
@ -169,7 +167,7 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), claims) token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), claims)
ss, err := token.SignedString(p.JWTKey) ss, err := token.SignedString(p.JWTKey)
if err != nil { if err != nil {
return return nil, err
} }
params := url.Values{} params := url.Values{}
@ -199,28 +197,27 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
// check nonce here // check nonce here
err = checkNonce(jsonResponse.IDToken, p) err = checkNonce(jsonResponse.IDToken, p)
if err != nil { if err != nil {
return return nil, err
} }
// Get the email address // Get the email address
var email string var email string
email, err = emailFromUserInfo(ctx, jsonResponse.AccessToken, p.ProfileURL.String()) email, err = emailFromUserInfo(ctx, jsonResponse.AccessToken, p.ProfileURL.String())
if err != nil { if err != nil {
return return nil, err
} }
created := time.Now() created := time.Now()
expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second) expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
// Store the data that we found in the session state // Store the data that we found in the session state
s = &sessions.SessionState{ return &sessions.SessionState{
AccessToken: jsonResponse.AccessToken, AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken, IDToken: jsonResponse.IDToken,
CreatedAt: &created, CreatedAt: &created,
ExpiresOn: &expires, ExpiresOn: &expires,
Email: email, Email: email,
} }, nil
return
} }
// GetLoginURL overrides GetLoginURL to add login.gov parameters // GetLoginURL overrides GetLoginURL to add login.gov parameters

View File

@ -26,6 +26,10 @@ type ProviderData struct {
ClientSecretFile string ClientSecretFile string
Scope string Scope string
Prompt string Prompt string
// Universal Group authorization data structure
// any provider can set to consume
AllowedGroups map[string]struct{}
} }
// Data returns the ProviderData // Data returns the ProviderData
@ -45,6 +49,15 @@ func (p *ProviderData) GetClientSecret() (clientSecret string, err error) {
return string(fileClientSecret), nil return string(fileClientSecret), nil
} }
// SetAllowedGroups organizes a group list into the AllowedGroups map
// to be consumed by Authorize implementations
func (p *ProviderData) SetAllowedGroups(groups []string) {
p.AllowedGroups = make(map[string]struct{}, len(groups))
for _, group := range groups {
p.AllowedGroups[group] = struct{}{}
}
}
type providerDefaults struct { type providerDefaults struct {
name string name string
loginURL *url.URL loginURL *url.URL

View File

@ -19,18 +19,21 @@ var (
// implementation method that doesn't have sensible defaults // implementation method that doesn't have sensible defaults
ErrNotImplemented = errors.New("not implemented") ErrNotImplemented = errors.New("not implemented")
// ErrMissingCode is returned when a Redeem method is called with an empty
// code
ErrMissingCode = errors.New("missing code")
_ Provider = (*ProviderData)(nil) _ Provider = (*ProviderData)(nil)
) )
// Redeem provides a default implementation of the OAuth2 token redemption process // Redeem provides a default implementation of the OAuth2 token redemption process
func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
if code == "" { if code == "" {
err = errors.New("missing code") return nil, ErrMissingCode
return
} }
clientSecret, err := p.GetClientSecret() clientSecret, err := p.GetClientSecret()
if err != nil { if err != nil {
return return nil, err
} }
params := url.Values{} params := url.Values{}
@ -59,24 +62,21 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
} }
err = result.UnmarshalInto(&jsonResponse) err = result.UnmarshalInto(&jsonResponse)
if err == nil { if err == nil {
s = &sessions.SessionState{ return &sessions.SessionState{
AccessToken: jsonResponse.AccessToken, AccessToken: jsonResponse.AccessToken,
} }, nil
return
} }
var v url.Values values, err := url.ParseQuery(string(result.Body()))
v, err = url.ParseQuery(string(result.Body()))
if err != nil { if err != nil {
return return nil, err
} }
if a := v.Get("access_token"); a != "" { if token := values.Get("access_token"); token != "" {
created := time.Now() created := time.Now()
s = &sessions.SessionState{AccessToken: a, CreatedAt: &created} return &sessions.SessionState{AccessToken: token, CreatedAt: &created}, nil
} else {
err = fmt.Errorf("no access token found %s", result.Body())
} }
return
return nil, fmt.Errorf("no access token found %s", result.Body())
} }
// GetLoginURL with typical oauth parameters // GetLoginURL with typical oauth parameters
@ -92,18 +92,28 @@ func (p *ProviderData) GetEmailAddress(_ context.Context, _ *sessions.SessionSta
return "", ErrNotImplemented return "", ErrNotImplemented
} }
// ValidateGroup validates that the provided email exists in the configured provider
// email group(s).
func (p *ProviderData) ValidateGroup(_ string) bool {
return true
}
// EnrichSessionState is called after Redeem to allow providers to enrich session fields // EnrichSessionState is called after Redeem to allow providers to enrich session fields
// such as User, Email, Groups with provider specific API calls. // such as User, Email, Groups with provider specific API calls.
func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.SessionState) error { func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.SessionState) error {
return nil return nil
} }
// Authorize performs global authorization on an authenticated session.
// This is not used for fine-grained per route authorization rules.
func (p *ProviderData) Authorize(_ context.Context, s *sessions.SessionState) (bool, error) {
if len(p.AllowedGroups) == 0 {
return true, nil
}
for _, group := range s.Groups {
if _, ok := p.AllowedGroups[group]; ok {
return true, nil
}
}
return false, nil
}
// ValidateSessionState validates the AccessToken // ValidateSessionState validates the AccessToken
func (p *ProviderData) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { func (p *ProviderData) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool {
return validateToken(ctx, p, s.AccessToken, nil) return validateToken(ctx, p, s.AccessToken, nil)

View File

@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -53,3 +54,53 @@ func TestEnrichSessionState(t *testing.T) {
s := &sessions.SessionState{} s := &sessions.SessionState{}
assert.NoError(t, p.EnrichSessionState(context.Background(), s)) assert.NoError(t, p.EnrichSessionState(context.Background(), s))
} }
func TestProviderDataAuthorize(t *testing.T) {
testCases := []struct {
name string
allowedGroups []string
groups []string
expectedAuthZ bool
}{
{
name: "NoAllowedGroups",
allowedGroups: []string{},
groups: []string{},
expectedAuthZ: true,
},
{
name: "NoAllowedGroupsUserHasGroups",
allowedGroups: []string{},
groups: []string{"foo", "bar"},
expectedAuthZ: true,
},
{
name: "UserInAllowedGroup",
allowedGroups: []string{"foo"},
groups: []string{"foo", "bar"},
expectedAuthZ: true,
},
{
name: "UserNotInAllowedGroup",
allowedGroups: []string{"bar"},
groups: []string{"baz", "foo"},
expectedAuthZ: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
g := NewWithT(t)
session := &sessions.SessionState{
Groups: tc.groups,
}
p := &ProviderData{}
p.SetAllowedGroups(tc.allowedGroups)
authorized, err := p.Authorize(context.Background(), session)
g.Expect(err).ToNot(HaveOccurred())
g.Expect(authorized).To(Equal(tc.expectedAuthZ))
})
}
}

View File

@ -13,8 +13,8 @@ type Provider interface {
// DEPRECATED: Migrate to EnrichSessionState // DEPRECATED: Migrate to EnrichSessionState
GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error)
Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error) Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error)
ValidateGroup(string) bool
EnrichSessionState(ctx context.Context, s *sessions.SessionState) error EnrichSessionState(ctx context.Context, s *sessions.SessionState) error
Authorize(ctx context.Context, s *sessions.SessionState) (bool, error)
ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool
GetLoginURL(redirectURI, finalRedirect string) string GetLoginURL(redirectURI, finalRedirect string) string
RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error)