mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-03-19 21:27:58 +02:00
Merge pull request #797 from grnhse/refactor-provider-authz
Centralize Provider authorization interface method
This commit is contained in:
commit
c377466411
@ -6,6 +6,11 @@
|
||||
|
||||
- [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Existing sessions from v6.0.0 or earlier are no longer valid. They will trigger a reauthentication.
|
||||
- [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) `skip-auth-strip-headers` now applies to all requests, not just those where authentication would be skipped.
|
||||
- [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) The behavior of the Google provider Groups restriction changes with this
|
||||
- Either `--google-group` or the new `--allowed-group` will work for Google now (`--google-group` will be used if both are set)
|
||||
- Group membership lists will be passed to the backend with the `X-Forwarded-Groups` header
|
||||
- If you change the list of allowed groups, existing sessions that now don't have a valid group will be logged out immediately.
|
||||
- Previously, group membership was only checked on session creation and refresh.
|
||||
- [#789](https://github.com/oauth2-proxy/oauth2-proxy/pull/789) `--skip-auth-route` is (almost) backwards compatible with `--skip-auth-regex`
|
||||
- We are marking `--skip-auth-regex` as DEPRECATED and will remove it in the next major version.
|
||||
- If your regex contains an `=` and you want it for all methods, you will need to add a leading `=` (this is the area where `--skip-auth-regex` doesn't port perfectly)
|
||||
@ -18,6 +23,9 @@
|
||||
## Breaking Changes
|
||||
|
||||
- [#911](https://github.com/oauth2-proxy/oauth2-rpoxy/pull/911) Specifying a non-existent provider will cause OAuth2-Proxy to fail on startup instead of defaulting to "google".
|
||||
- [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) Security changes to Google provider group authorization flow
|
||||
- If you change the list of allowed groups, existing sessions that now don't have a valid group will be logged out immediately.
|
||||
- Previously, group membership was only checked on session creation and refresh.
|
||||
- [#722](https://github.com/oauth2-proxy/oauth2-proxy/pull/722) When a Redis session store is configured, OAuth2-Proxy will fail to start up unless connection and health checks to Redis pass
|
||||
- [#800](https://github.com/oauth2-proxy/oauth2-proxy/pull/800) Fix import path for v7. The import path has changed to support the go get installation.
|
||||
- You can now `go get github.com/oauth2-proxy/oauth2-proxy/v7` to get the latest `v7` version of OAuth2 Proxy
|
||||
@ -40,6 +48,7 @@
|
||||
- [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Remove v5 legacy sessions support (@NickMeves)
|
||||
- [#904](https://github.com/oauth2-proxy/oauth2-proxy/pull/904) Set `skip-auth-strip-headers` to `true` by default (@NickMeves)
|
||||
- [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) Integrate new header injectors into project (@JoelSpeed)
|
||||
- [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) Create universal Authorization behavior across providers (@NickMeves)
|
||||
- [#898](https://github.com/oauth2-proxy/oauth2-proxy/pull/898) Migrate documentation to Docusaurus (@JoelSpeed)
|
||||
- [#754](https://github.com/oauth2-proxy/oauth2-proxy/pull/754) Azure token refresh (@codablock)
|
||||
- [#825](https://github.com/oauth2-proxy/oauth2-proxy/pull/825) Fix code coverage reporting on GitHub actions(@JoelSpeed)
|
||||
|
@ -42,6 +42,9 @@ var (
|
||||
// ErrNeedsLogin means the user should be redirected to the login page
|
||||
ErrNeedsLogin = errors.New("redirect to login page")
|
||||
|
||||
// ErrAccessDenied means the user should receive a 401 Unauthorized response
|
||||
ErrAccessDenied = errors.New("access denied")
|
||||
|
||||
// Used to check final redirects are not susceptible to open redirects.
|
||||
// Matches //, /\ and both of these with whitespace in between (eg / / or / \).
|
||||
invalidRedirectRegex = regexp.MustCompile(`[/\\](?:[\s\v]*|\.{1,2})[/\\]`)
|
||||
@ -105,7 +108,6 @@ type OAuthProxy struct {
|
||||
trustedIPs *ip.NetSet
|
||||
Banner string
|
||||
Footer string
|
||||
AllowedGroups []string
|
||||
|
||||
sessionChain alice.Chain
|
||||
headersChain alice.Chain
|
||||
@ -219,7 +221,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
||||
Banner: opts.Banner,
|
||||
Footer: opts.Footer,
|
||||
SignInMessage: buildSignInMessage(opts),
|
||||
AllowedGroups: opts.AllowedGroups,
|
||||
|
||||
basicAuthValidator: basicAuthValidator,
|
||||
displayHtpasswdForm: basicAuthValidator != nil && opts.DisplayHtpasswdForm,
|
||||
@ -396,7 +397,7 @@ func (p *OAuthProxy) GetRedirectURI(host string) string {
|
||||
|
||||
func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("missing code")
|
||||
return nil, providers.ErrMissingCode
|
||||
}
|
||||
redirectURI := p.GetRedirectURI(host)
|
||||
s, err := p.provider.Redeem(ctx, redirectURI, code)
|
||||
@ -909,11 +910,15 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
|
||||
// set cookie, or deny
|
||||
if p.Validator(session.Email) && p.provider.ValidateGroup(session.Email) {
|
||||
authorized, err := p.provider.Authorize(req.Context(), session)
|
||||
if err != nil {
|
||||
logger.Errorf("Error with authorization: %v", err)
|
||||
}
|
||||
if p.Validator(session.Email) && authorized {
|
||||
logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session)
|
||||
err := p.SaveSession(rw, req, session)
|
||||
if err != nil {
|
||||
logger.Printf("Error saving session state for %s: %v", remoteAddr, err)
|
||||
logger.Errorf("Error saving session state for %s: %v", remoteAddr, err)
|
||||
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
||||
return
|
||||
}
|
||||
@ -967,6 +972,9 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
|
||||
p.SignInPage(rw, req, http.StatusForbidden)
|
||||
}
|
||||
|
||||
case ErrAccessDenied:
|
||||
p.ErrorPage(rw, http.StatusUnauthorized, "Permission Denied", "Unauthorized")
|
||||
|
||||
default:
|
||||
// unknown error
|
||||
logger.Errorf("Unexpected internal error: %v", err)
|
||||
@ -977,7 +985,9 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
|
||||
// getAuthenticatedSession checks whether a user is authenticated and returns a session object and nil error if so
|
||||
// Returns nil, ErrNeedsLogin if user needs to login.
|
||||
// Returns:
|
||||
// - `nil, ErrNeedsLogin` if user needs to login.
|
||||
// - `nil, ErrAccessDenied` if the authenticated user is not authorized
|
||||
// Set-Cookie headers may be set on the response as a side-effect of calling this method.
|
||||
func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) {
|
||||
var session *sessionsapi.SessionState
|
||||
@ -991,17 +1001,20 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R
|
||||
return nil, ErrNeedsLogin
|
||||
}
|
||||
|
||||
invalidEmail := session != nil && session.Email != "" && !p.Validator(session.Email)
|
||||
invalidGroups := session != nil && !p.validateGroups(session.Groups)
|
||||
invalidEmail := session.Email != "" && !p.Validator(session.Email)
|
||||
authorized, err := p.provider.Authorize(req.Context(), session)
|
||||
if err != nil {
|
||||
logger.Errorf("Error with authorization: %v", err)
|
||||
}
|
||||
|
||||
if invalidEmail || invalidGroups {
|
||||
logger.Printf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session)
|
||||
if invalidEmail || !authorized {
|
||||
logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authorization via session: removing session %s", session)
|
||||
// Invalid session, clear it
|
||||
err := p.ClearSessionCookie(rw, req)
|
||||
if err != nil {
|
||||
logger.Printf("Error clearing session cookie: %v", err)
|
||||
logger.Errorf("Error clearing session cookie: %v", err)
|
||||
}
|
||||
return nil, ErrNeedsLogin
|
||||
return nil, ErrAccessDenied
|
||||
}
|
||||
|
||||
return session, nil
|
||||
@ -1033,23 +1046,3 @@ func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) {
|
||||
rw.Header().Set("Content-Type", applicationJSON)
|
||||
rw.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) validateGroups(groups []string) bool {
|
||||
if len(p.AllowedGroups) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
allowedGroups := map[string]struct{}{}
|
||||
|
||||
for _, group := range p.AllowedGroups {
|
||||
allowedGroups[group] = struct{}{}
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
if _, ok := allowedGroups[group]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
@ -976,8 +976,10 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi
|
||||
return nil, err
|
||||
}
|
||||
pcTest.proxy.provider = &TestProvider{
|
||||
ValidToken: opts.providerValidateCookieResponse,
|
||||
ProviderData: &providers.ProviderData{},
|
||||
ValidToken: opts.providerValidateCookieResponse,
|
||||
}
|
||||
pcTest.proxy.provider.(*TestProvider).SetAllowedGroups(pcTest.opts.AllowedGroups)
|
||||
|
||||
// Now, zero-out proxy.CookieRefresh for the cases that don't involve
|
||||
// access_token validation.
|
||||
@ -1284,7 +1286,8 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pcTest.proxy.provider = &TestProvider{
|
||||
ValidToken: true,
|
||||
ProviderData: &providers.ProviderData{},
|
||||
ValidToken: true,
|
||||
}
|
||||
|
||||
pcTest.validateUser = true
|
||||
@ -1376,7 +1379,8 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pcTest.proxy.provider = &TestProvider{
|
||||
ValidToken: true,
|
||||
ProviderData: &providers.ProviderData{},
|
||||
ValidToken: true,
|
||||
}
|
||||
|
||||
pcTest.validateUser = true
|
||||
@ -1455,7 +1459,8 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pcTest.proxy.provider = &TestProvider{
|
||||
ValidToken: true,
|
||||
ProviderData: &providers.ProviderData{},
|
||||
ValidToken: true,
|
||||
}
|
||||
|
||||
pcTest.validateUser = true
|
||||
|
@ -233,6 +233,8 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
|
||||
p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs)
|
||||
p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs)
|
||||
|
||||
p.SetAllowedGroups(o.AllowedGroups)
|
||||
|
||||
provider := providers.New(o.ProviderType, p)
|
||||
if provider == nil {
|
||||
msgs = append(msgs, fmt.Sprintf("invalid setting: provider '%s' is not available", o.ProviderType))
|
||||
@ -255,7 +257,13 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
|
||||
if err != nil {
|
||||
msgs = append(msgs, "invalid Google credentials file: "+o.GoogleServiceAccountJSON)
|
||||
} else {
|
||||
p.SetGroupRestriction(o.GoogleGroups, o.GoogleAdminEmail, file)
|
||||
groups := o.AllowedGroups
|
||||
// Backwards compatibility with `--google-group` option
|
||||
if len(o.GoogleGroups) > 0 {
|
||||
groups = o.GoogleGroups
|
||||
p.SetAllowedGroups(groups)
|
||||
}
|
||||
p.SetGroupRestriction(groups, o.GoogleAdminEmail, file)
|
||||
}
|
||||
}
|
||||
case *providers.BitbucketProvider:
|
||||
|
@ -108,14 +108,13 @@ func overrideTenantURL(current, defaultURL *url.URL, tenant, path string) {
|
||||
}
|
||||
|
||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||
func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||
func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
|
||||
if code == "" {
|
||||
err = errors.New("missing code")
|
||||
return
|
||||
return nil, ErrMissingCode
|
||||
}
|
||||
clientSecret, err := p.GetClientSecret()
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
@ -149,15 +148,14 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s
|
||||
|
||||
created := time.Now()
|
||||
expires := time.Unix(jsonResponse.ExpiresOn, 0)
|
||||
s = &sessions.SessionState{
|
||||
|
||||
return &sessions.SessionState{
|
||||
AccessToken: jsonResponse.AccessToken,
|
||||
IDToken: jsonResponse.IDToken,
|
||||
CreatedAt: &created,
|
||||
ExpiresOn: &expires,
|
||||
RefreshToken: jsonResponse.RefreshToken,
|
||||
}
|
||||
return
|
||||
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||
|
@ -25,10 +25,16 @@ import (
|
||||
// GoogleProvider represents an Google based Identity Provider
|
||||
type GoogleProvider struct {
|
||||
*ProviderData
|
||||
|
||||
RedeemRefreshURL *url.URL
|
||||
// GroupValidator is a function that determines if the passed email is in
|
||||
// the configured Google group.
|
||||
GroupValidator func(string) bool
|
||||
|
||||
// groupValidator is a function that determines if the user in the passed
|
||||
// session is a member of any of the configured Google groups.
|
||||
//
|
||||
// This hits the Google API for each group, so it is called on Redeem &
|
||||
// Refresh. `Authorize` uses the results of this saved in `session.Groups`
|
||||
// Since it is called on every request.
|
||||
groupValidator func(*sessions.SessionState) bool
|
||||
}
|
||||
|
||||
var _ Provider = (*GoogleProvider)(nil)
|
||||
@ -84,9 +90,9 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider {
|
||||
})
|
||||
return &GoogleProvider{
|
||||
ProviderData: p,
|
||||
// Set a default GroupValidator to just always return valid (true), it will
|
||||
// Set a default groupValidator to just always return valid (true), it will
|
||||
// be overwritten if we configured a Google group restriction.
|
||||
GroupValidator: func(email string) bool {
|
||||
groupValidator: func(*sessions.SessionState) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
@ -118,14 +124,13 @@ func claimsFromIDToken(idToken string) (*claims, error) {
|
||||
}
|
||||
|
||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||
func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||
func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
|
||||
if code == "" {
|
||||
err = errors.New("missing code")
|
||||
return
|
||||
return nil, ErrMissingCode
|
||||
}
|
||||
clientSecret, err := p.GetClientSecret()
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
@ -155,12 +160,13 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
|
||||
|
||||
c, err := claimsFromIDToken(jsonResponse.IDToken)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
created := time.Now()
|
||||
expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
|
||||
s = &sessions.SessionState{
|
||||
|
||||
return &sessions.SessionState{
|
||||
AccessToken: jsonResponse.AccessToken,
|
||||
IDToken: jsonResponse.IDToken,
|
||||
CreatedAt: &created,
|
||||
@ -168,18 +174,40 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
|
||||
RefreshToken: jsonResponse.RefreshToken,
|
||||
Email: c.Email,
|
||||
User: c.Subject,
|
||||
}
|
||||
return
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EnrichSessionState checks the listed Google Groups configured and adds any
|
||||
// that the user is a member of to session.Groups.
|
||||
func (p *GoogleProvider) EnrichSessionState(ctx context.Context, s *sessions.SessionState) error {
|
||||
// TODO (@NickMeves) - Move to pure EnrichSessionState logic and stop
|
||||
// reusing legacy `groupValidator`.
|
||||
//
|
||||
// This is called here to get the validator to do the `session.Groups`
|
||||
// populating logic.
|
||||
p.groupValidator(s)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetGroupRestriction configures the GoogleProvider to restrict access to the
|
||||
// specified group(s). AdminEmail has to be an administrative email on the domain that is
|
||||
// checked. CredentialsFile is the path to a json file containing a Google service
|
||||
// account credentials.
|
||||
//
|
||||
// TODO (@NickMeves) - Unit Test this OR refactor away from groupValidator func
|
||||
func (p *GoogleProvider) SetGroupRestriction(groups []string, adminEmail string, credentialsReader io.Reader) {
|
||||
adminService := getAdminService(adminEmail, credentialsReader)
|
||||
p.GroupValidator = func(email string) bool {
|
||||
return userInGroup(adminService, groups, email)
|
||||
p.groupValidator = func(s *sessions.SessionState) bool {
|
||||
// Reset our saved Groups in case membership changed
|
||||
// This is used by `Authorize` on every request
|
||||
s.Groups = make([]string, 0, len(groups))
|
||||
for _, group := range groups {
|
||||
if userInGroup(adminService, group, s.Email) {
|
||||
s.Groups = append(s.Groups, group)
|
||||
}
|
||||
}
|
||||
return len(s.Groups) > 0
|
||||
}
|
||||
}
|
||||
|
||||
@ -203,52 +231,41 @@ func getAdminService(adminEmail string, credentialsReader io.Reader) *admin.Serv
|
||||
return adminService
|
||||
}
|
||||
|
||||
func userInGroup(service *admin.Service, groups []string, email string) bool {
|
||||
for _, group := range groups {
|
||||
// Use the HasMember API to checking for the user's presence in each group or nested subgroups
|
||||
req := service.Members.HasMember(group, email)
|
||||
func userInGroup(service *admin.Service, group string, email string) bool {
|
||||
// Use the HasMember API to checking for the user's presence in each group or nested subgroups
|
||||
req := service.Members.HasMember(group, email)
|
||||
r, err := req.Do()
|
||||
if err == nil {
|
||||
return r.IsMember
|
||||
}
|
||||
|
||||
gerr, ok := err.(*googleapi.Error)
|
||||
switch {
|
||||
case ok && gerr.Code == 404:
|
||||
logger.Errorf("error checking membership in group %s: group does not exist", group)
|
||||
case ok && gerr.Code == 400:
|
||||
// It is possible for Members.HasMember to return false even if the email is a group member.
|
||||
// One case that can cause this is if the user email is from a different domain than the group,
|
||||
// e.g. "member@otherdomain.com" in the group "group@mydomain.com" will result in a 400 error
|
||||
// from the HasMember API. In that case, attempt to query the member object directly from the group.
|
||||
req := service.Members.Get(group, email)
|
||||
r, err := req.Do()
|
||||
if err != nil {
|
||||
gerr, ok := err.(*googleapi.Error)
|
||||
switch {
|
||||
case ok && gerr.Code == 404:
|
||||
logger.Errorf("error checking membership in group %s: group does not exist", group)
|
||||
case ok && gerr.Code == 400:
|
||||
// It is possible for Members.HasMember to return false even if the email is a group member.
|
||||
// One case that can cause this is if the user email is from a different domain than the group,
|
||||
// e.g. "member@otherdomain.com" in the group "group@mydomain.com" will result in a 400 error
|
||||
// from the HasMember API. In that case, attempt to query the member object directly from the group.
|
||||
req := service.Members.Get(group, email)
|
||||
r, err := req.Do()
|
||||
|
||||
if err != nil {
|
||||
logger.Errorf("error using get API to check member %s of google group %s: user not in the group", email, group)
|
||||
continue
|
||||
}
|
||||
|
||||
// If the non-domain user is found within the group, still verify that they are "ACTIVE".
|
||||
// Do not count the user as belonging to a group if they have another status ("ARCHIVED", "SUSPENDED", or "UNKNOWN").
|
||||
if r.Status == "ACTIVE" {
|
||||
return true
|
||||
}
|
||||
default:
|
||||
logger.Errorf("error checking group membership: %v", err)
|
||||
}
|
||||
continue
|
||||
logger.Errorf("error using get API to check member %s of google group %s: user not in the group", email, group)
|
||||
return false
|
||||
}
|
||||
if r.IsMember {
|
||||
|
||||
// If the non-domain user is found within the group, still verify that they are "ACTIVE".
|
||||
// Do not count the user as belonging to a group if they have another status ("ARCHIVED", "SUSPENDED", or "UNKNOWN").
|
||||
if r.Status == "ACTIVE" {
|
||||
return true
|
||||
}
|
||||
default:
|
||||
logger.Errorf("error checking group membership: %v", err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ValidateGroup validates that the provided email exists in the configured Google
|
||||
// group(s).
|
||||
func (p *GoogleProvider) ValidateGroup(email string) bool {
|
||||
return p.GroupValidator(email)
|
||||
}
|
||||
|
||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||
// RefreshToken to fetch a new ID token if required
|
||||
func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
@ -261,8 +278,11 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
|
||||
return false, err
|
||||
}
|
||||
|
||||
// TODO (@NickMeves) - Align Group authorization needs with other providers'
|
||||
// behavior in the `RefreshSession` case.
|
||||
//
|
||||
// re-check that the user is in the proper google group(s)
|
||||
if !p.ValidateGroup(s.Email) {
|
||||
if !p.groupValidator(s) {
|
||||
return false, fmt.Errorf("%s is no longer in the group(s)", s.Email)
|
||||
}
|
||||
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/assert"
|
||||
admin "google.golang.org/api/admin/directory/v1"
|
||||
@ -109,21 +110,50 @@ func TestGoogleProviderGetEmailAddress(t *testing.T) {
|
||||
assert.Equal(t, "refresh12345", session.RefreshToken)
|
||||
}
|
||||
|
||||
func TestGoogleProviderValidateGroup(t *testing.T) {
|
||||
p := newGoogleProvider()
|
||||
p.GroupValidator = func(email string) bool {
|
||||
return email == "michael.bland@gsa.gov"
|
||||
}
|
||||
assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov"))
|
||||
p.GroupValidator = func(email string) bool {
|
||||
return email != "michael.bland@gsa.gov"
|
||||
}
|
||||
assert.Equal(t, false, p.ValidateGroup("michael.bland@gsa.gov"))
|
||||
}
|
||||
func TestGoogleProviderGroupValidator(t *testing.T) {
|
||||
const sessionEmail = "michael.bland@gsa.gov"
|
||||
|
||||
func TestGoogleProviderWithoutValidateGroup(t *testing.T) {
|
||||
p := newGoogleProvider()
|
||||
assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov"))
|
||||
testCases := map[string]struct {
|
||||
session *sessions.SessionState
|
||||
validatorFunc func(*sessions.SessionState) bool
|
||||
expectedAuthZ bool
|
||||
}{
|
||||
"Email is authorized with groupValidator": {
|
||||
session: &sessions.SessionState{
|
||||
Email: sessionEmail,
|
||||
},
|
||||
validatorFunc: func(s *sessions.SessionState) bool {
|
||||
return s.Email == sessionEmail
|
||||
},
|
||||
expectedAuthZ: true,
|
||||
},
|
||||
"Email is denied with groupValidator": {
|
||||
session: &sessions.SessionState{
|
||||
Email: sessionEmail,
|
||||
},
|
||||
validatorFunc: func(s *sessions.SessionState) bool {
|
||||
return s.Email != sessionEmail
|
||||
},
|
||||
expectedAuthZ: false,
|
||||
},
|
||||
"Default does no authorization checks": {
|
||||
session: &sessions.SessionState{
|
||||
Email: sessionEmail,
|
||||
},
|
||||
validatorFunc: nil,
|
||||
expectedAuthZ: true,
|
||||
},
|
||||
}
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
p := newGoogleProvider()
|
||||
if tc.validatorFunc != nil {
|
||||
p.groupValidator = tc.validatorFunc
|
||||
}
|
||||
g.Expect(p.groupValidator(tc.session)).To(Equal(tc.expectedAuthZ))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
@ -196,7 +226,7 @@ func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func TestGoogleProviderUserInGroup(t *testing.T) {
|
||||
func TestGoogleProvider_userInGroup(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/groups/group@example.com/hasMember/member-in-domain@example.com" {
|
||||
fmt.Fprintln(w, `{"isMember": true}`)
|
||||
@ -233,18 +263,19 @@ func TestGoogleProviderUserInGroup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
service, err := admin.NewService(ctx, option.WithHTTPClient(client))
|
||||
assert.NoError(t, err)
|
||||
|
||||
service.BasePath = ts.URL
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
result := userInGroup(service, []string{"group@example.com"}, "member-in-domain@example.com")
|
||||
result := userInGroup(service, "group@example.com", "member-in-domain@example.com")
|
||||
assert.True(t, result)
|
||||
|
||||
result = userInGroup(service, []string{"group@example.com"}, "member-out-of-domain@otherexample.com")
|
||||
result = userInGroup(service, "group@example.com", "member-out-of-domain@otherexample.com")
|
||||
assert.True(t, result)
|
||||
|
||||
result = userInGroup(service, []string{"group@example.com"}, "non-member-in-domain@example.com")
|
||||
result = userInGroup(service, "group@example.com", "non-member-in-domain@example.com")
|
||||
assert.False(t, result)
|
||||
|
||||
result = userInGroup(service, []string{"group@example.com"}, "non-member-out-of-domain@otherexample.com")
|
||||
result = userInGroup(service, "group@example.com", "non-member-out-of-domain@otherexample.com")
|
||||
assert.False(t, result)
|
||||
}
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
@ -153,10 +152,9 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint
|
||||
}
|
||||
|
||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||
func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||
func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
|
||||
if code == "" {
|
||||
err = errors.New("missing code")
|
||||
return
|
||||
return nil, ErrMissingCode
|
||||
}
|
||||
|
||||
claims := &jwt.StandardClaims{
|
||||
@ -169,7 +167,7 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
|
||||
token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), claims)
|
||||
ss, err := token.SignedString(p.JWTKey)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
@ -199,28 +197,27 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
|
||||
// check nonce here
|
||||
err = checkNonce(jsonResponse.IDToken, p)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the email address
|
||||
var email string
|
||||
email, err = emailFromUserInfo(ctx, jsonResponse.AccessToken, p.ProfileURL.String())
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
created := time.Now()
|
||||
expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
|
||||
|
||||
// Store the data that we found in the session state
|
||||
s = &sessions.SessionState{
|
||||
return &sessions.SessionState{
|
||||
AccessToken: jsonResponse.AccessToken,
|
||||
IDToken: jsonResponse.IDToken,
|
||||
CreatedAt: &created,
|
||||
ExpiresOn: &expires,
|
||||
Email: email,
|
||||
}
|
||||
return
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetLoginURL overrides GetLoginURL to add login.gov parameters
|
||||
|
@ -26,6 +26,10 @@ type ProviderData struct {
|
||||
ClientSecretFile string
|
||||
Scope string
|
||||
Prompt string
|
||||
|
||||
// Universal Group authorization data structure
|
||||
// any provider can set to consume
|
||||
AllowedGroups map[string]struct{}
|
||||
}
|
||||
|
||||
// Data returns the ProviderData
|
||||
@ -45,6 +49,15 @@ func (p *ProviderData) GetClientSecret() (clientSecret string, err error) {
|
||||
return string(fileClientSecret), nil
|
||||
}
|
||||
|
||||
// SetAllowedGroups organizes a group list into the AllowedGroups map
|
||||
// to be consumed by Authorize implementations
|
||||
func (p *ProviderData) SetAllowedGroups(groups []string) {
|
||||
p.AllowedGroups = make(map[string]struct{}, len(groups))
|
||||
for _, group := range groups {
|
||||
p.AllowedGroups[group] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
type providerDefaults struct {
|
||||
name string
|
||||
loginURL *url.URL
|
||||
|
@ -19,18 +19,21 @@ var (
|
||||
// implementation method that doesn't have sensible defaults
|
||||
ErrNotImplemented = errors.New("not implemented")
|
||||
|
||||
// ErrMissingCode is returned when a Redeem method is called with an empty
|
||||
// code
|
||||
ErrMissingCode = errors.New("missing code")
|
||||
|
||||
_ Provider = (*ProviderData)(nil)
|
||||
)
|
||||
|
||||
// Redeem provides a default implementation of the OAuth2 token redemption process
|
||||
func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||
func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
|
||||
if code == "" {
|
||||
err = errors.New("missing code")
|
||||
return
|
||||
return nil, ErrMissingCode
|
||||
}
|
||||
clientSecret, err := p.GetClientSecret()
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
@ -59,24 +62,21 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
|
||||
}
|
||||
err = result.UnmarshalInto(&jsonResponse)
|
||||
if err == nil {
|
||||
s = &sessions.SessionState{
|
||||
return &sessions.SessionState{
|
||||
AccessToken: jsonResponse.AccessToken,
|
||||
}
|
||||
return
|
||||
}, nil
|
||||
}
|
||||
|
||||
var v url.Values
|
||||
v, err = url.ParseQuery(string(result.Body()))
|
||||
values, err := url.ParseQuery(string(result.Body()))
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
if a := v.Get("access_token"); a != "" {
|
||||
if token := values.Get("access_token"); token != "" {
|
||||
created := time.Now()
|
||||
s = &sessions.SessionState{AccessToken: a, CreatedAt: &created}
|
||||
} else {
|
||||
err = fmt.Errorf("no access token found %s", result.Body())
|
||||
return &sessions.SessionState{AccessToken: token, CreatedAt: &created}, nil
|
||||
}
|
||||
return
|
||||
|
||||
return nil, fmt.Errorf("no access token found %s", result.Body())
|
||||
}
|
||||
|
||||
// GetLoginURL with typical oauth parameters
|
||||
@ -92,18 +92,28 @@ func (p *ProviderData) GetEmailAddress(_ context.Context, _ *sessions.SessionSta
|
||||
return "", ErrNotImplemented
|
||||
}
|
||||
|
||||
// ValidateGroup validates that the provided email exists in the configured provider
|
||||
// email group(s).
|
||||
func (p *ProviderData) ValidateGroup(_ string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// EnrichSessionState is called after Redeem to allow providers to enrich session fields
|
||||
// such as User, Email, Groups with provider specific API calls.
|
||||
func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.SessionState) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Authorize performs global authorization on an authenticated session.
|
||||
// This is not used for fine-grained per route authorization rules.
|
||||
func (p *ProviderData) Authorize(_ context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if len(p.AllowedGroups) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
for _, group := range s.Groups {
|
||||
if _, ok := p.AllowedGroups[group]; ok {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// ValidateSessionState validates the AccessToken
|
||||
func (p *ProviderData) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool {
|
||||
return validateToken(ctx, p, s.AccessToken, nil)
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@ -53,3 +54,53 @@ func TestEnrichSessionState(t *testing.T) {
|
||||
s := &sessions.SessionState{}
|
||||
assert.NoError(t, p.EnrichSessionState(context.Background(), s))
|
||||
}
|
||||
|
||||
func TestProviderDataAuthorize(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
allowedGroups []string
|
||||
groups []string
|
||||
expectedAuthZ bool
|
||||
}{
|
||||
{
|
||||
name: "NoAllowedGroups",
|
||||
allowedGroups: []string{},
|
||||
groups: []string{},
|
||||
expectedAuthZ: true,
|
||||
},
|
||||
{
|
||||
name: "NoAllowedGroupsUserHasGroups",
|
||||
allowedGroups: []string{},
|
||||
groups: []string{"foo", "bar"},
|
||||
expectedAuthZ: true,
|
||||
},
|
||||
{
|
||||
name: "UserInAllowedGroup",
|
||||
allowedGroups: []string{"foo"},
|
||||
groups: []string{"foo", "bar"},
|
||||
expectedAuthZ: true,
|
||||
},
|
||||
{
|
||||
name: "UserNotInAllowedGroup",
|
||||
allowedGroups: []string{"bar"},
|
||||
groups: []string{"baz", "foo"},
|
||||
expectedAuthZ: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
|
||||
session := &sessions.SessionState{
|
||||
Groups: tc.groups,
|
||||
}
|
||||
p := &ProviderData{}
|
||||
p.SetAllowedGroups(tc.allowedGroups)
|
||||
|
||||
authorized, err := p.Authorize(context.Background(), session)
|
||||
g.Expect(err).ToNot(HaveOccurred())
|
||||
g.Expect(authorized).To(Equal(tc.expectedAuthZ))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -13,8 +13,8 @@ type Provider interface {
|
||||
// DEPRECATED: Migrate to EnrichSessionState
|
||||
GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error)
|
||||
Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error)
|
||||
ValidateGroup(string) bool
|
||||
EnrichSessionState(ctx context.Context, s *sessions.SessionState) error
|
||||
Authorize(ctx context.Context, s *sessions.SessionState) (bool, error)
|
||||
ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool
|
||||
GetLoginURL(redirectURI, finalRedirect string) string
|
||||
RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error)
|
||||
|
Loading…
x
Reference in New Issue
Block a user