mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-03-21 21:47:11 +02:00
Merge pull request #797 from grnhse/refactor-provider-authz
Centralize Provider authorization interface method
This commit is contained in:
commit
c377466411
@ -6,6 +6,11 @@
|
|||||||
|
|
||||||
- [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Existing sessions from v6.0.0 or earlier are no longer valid. They will trigger a reauthentication.
|
- [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Existing sessions from v6.0.0 or earlier are no longer valid. They will trigger a reauthentication.
|
||||||
- [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) `skip-auth-strip-headers` now applies to all requests, not just those where authentication would be skipped.
|
- [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) `skip-auth-strip-headers` now applies to all requests, not just those where authentication would be skipped.
|
||||||
|
- [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) The behavior of the Google provider Groups restriction changes with this
|
||||||
|
- Either `--google-group` or the new `--allowed-group` will work for Google now (`--google-group` will be used if both are set)
|
||||||
|
- Group membership lists will be passed to the backend with the `X-Forwarded-Groups` header
|
||||||
|
- If you change the list of allowed groups, existing sessions that now don't have a valid group will be logged out immediately.
|
||||||
|
- Previously, group membership was only checked on session creation and refresh.
|
||||||
- [#789](https://github.com/oauth2-proxy/oauth2-proxy/pull/789) `--skip-auth-route` is (almost) backwards compatible with `--skip-auth-regex`
|
- [#789](https://github.com/oauth2-proxy/oauth2-proxy/pull/789) `--skip-auth-route` is (almost) backwards compatible with `--skip-auth-regex`
|
||||||
- We are marking `--skip-auth-regex` as DEPRECATED and will remove it in the next major version.
|
- We are marking `--skip-auth-regex` as DEPRECATED and will remove it in the next major version.
|
||||||
- If your regex contains an `=` and you want it for all methods, you will need to add a leading `=` (this is the area where `--skip-auth-regex` doesn't port perfectly)
|
- If your regex contains an `=` and you want it for all methods, you will need to add a leading `=` (this is the area where `--skip-auth-regex` doesn't port perfectly)
|
||||||
@ -18,6 +23,9 @@
|
|||||||
## Breaking Changes
|
## Breaking Changes
|
||||||
|
|
||||||
- [#911](https://github.com/oauth2-proxy/oauth2-rpoxy/pull/911) Specifying a non-existent provider will cause OAuth2-Proxy to fail on startup instead of defaulting to "google".
|
- [#911](https://github.com/oauth2-proxy/oauth2-rpoxy/pull/911) Specifying a non-existent provider will cause OAuth2-Proxy to fail on startup instead of defaulting to "google".
|
||||||
|
- [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) Security changes to Google provider group authorization flow
|
||||||
|
- If you change the list of allowed groups, existing sessions that now don't have a valid group will be logged out immediately.
|
||||||
|
- Previously, group membership was only checked on session creation and refresh.
|
||||||
- [#722](https://github.com/oauth2-proxy/oauth2-proxy/pull/722) When a Redis session store is configured, OAuth2-Proxy will fail to start up unless connection and health checks to Redis pass
|
- [#722](https://github.com/oauth2-proxy/oauth2-proxy/pull/722) When a Redis session store is configured, OAuth2-Proxy will fail to start up unless connection and health checks to Redis pass
|
||||||
- [#800](https://github.com/oauth2-proxy/oauth2-proxy/pull/800) Fix import path for v7. The import path has changed to support the go get installation.
|
- [#800](https://github.com/oauth2-proxy/oauth2-proxy/pull/800) Fix import path for v7. The import path has changed to support the go get installation.
|
||||||
- You can now `go get github.com/oauth2-proxy/oauth2-proxy/v7` to get the latest `v7` version of OAuth2 Proxy
|
- You can now `go get github.com/oauth2-proxy/oauth2-proxy/v7` to get the latest `v7` version of OAuth2 Proxy
|
||||||
@ -40,6 +48,7 @@
|
|||||||
- [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Remove v5 legacy sessions support (@NickMeves)
|
- [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Remove v5 legacy sessions support (@NickMeves)
|
||||||
- [#904](https://github.com/oauth2-proxy/oauth2-proxy/pull/904) Set `skip-auth-strip-headers` to `true` by default (@NickMeves)
|
- [#904](https://github.com/oauth2-proxy/oauth2-proxy/pull/904) Set `skip-auth-strip-headers` to `true` by default (@NickMeves)
|
||||||
- [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) Integrate new header injectors into project (@JoelSpeed)
|
- [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) Integrate new header injectors into project (@JoelSpeed)
|
||||||
|
- [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) Create universal Authorization behavior across providers (@NickMeves)
|
||||||
- [#898](https://github.com/oauth2-proxy/oauth2-proxy/pull/898) Migrate documentation to Docusaurus (@JoelSpeed)
|
- [#898](https://github.com/oauth2-proxy/oauth2-proxy/pull/898) Migrate documentation to Docusaurus (@JoelSpeed)
|
||||||
- [#754](https://github.com/oauth2-proxy/oauth2-proxy/pull/754) Azure token refresh (@codablock)
|
- [#754](https://github.com/oauth2-proxy/oauth2-proxy/pull/754) Azure token refresh (@codablock)
|
||||||
- [#825](https://github.com/oauth2-proxy/oauth2-proxy/pull/825) Fix code coverage reporting on GitHub actions(@JoelSpeed)
|
- [#825](https://github.com/oauth2-proxy/oauth2-proxy/pull/825) Fix code coverage reporting on GitHub actions(@JoelSpeed)
|
||||||
|
@ -42,6 +42,9 @@ var (
|
|||||||
// ErrNeedsLogin means the user should be redirected to the login page
|
// ErrNeedsLogin means the user should be redirected to the login page
|
||||||
ErrNeedsLogin = errors.New("redirect to login page")
|
ErrNeedsLogin = errors.New("redirect to login page")
|
||||||
|
|
||||||
|
// ErrAccessDenied means the user should receive a 401 Unauthorized response
|
||||||
|
ErrAccessDenied = errors.New("access denied")
|
||||||
|
|
||||||
// Used to check final redirects are not susceptible to open redirects.
|
// Used to check final redirects are not susceptible to open redirects.
|
||||||
// Matches //, /\ and both of these with whitespace in between (eg / / or / \).
|
// Matches //, /\ and both of these with whitespace in between (eg / / or / \).
|
||||||
invalidRedirectRegex = regexp.MustCompile(`[/\\](?:[\s\v]*|\.{1,2})[/\\]`)
|
invalidRedirectRegex = regexp.MustCompile(`[/\\](?:[\s\v]*|\.{1,2})[/\\]`)
|
||||||
@ -105,7 +108,6 @@ type OAuthProxy struct {
|
|||||||
trustedIPs *ip.NetSet
|
trustedIPs *ip.NetSet
|
||||||
Banner string
|
Banner string
|
||||||
Footer string
|
Footer string
|
||||||
AllowedGroups []string
|
|
||||||
|
|
||||||
sessionChain alice.Chain
|
sessionChain alice.Chain
|
||||||
headersChain alice.Chain
|
headersChain alice.Chain
|
||||||
@ -219,7 +221,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
|||||||
Banner: opts.Banner,
|
Banner: opts.Banner,
|
||||||
Footer: opts.Footer,
|
Footer: opts.Footer,
|
||||||
SignInMessage: buildSignInMessage(opts),
|
SignInMessage: buildSignInMessage(opts),
|
||||||
AllowedGroups: opts.AllowedGroups,
|
|
||||||
|
|
||||||
basicAuthValidator: basicAuthValidator,
|
basicAuthValidator: basicAuthValidator,
|
||||||
displayHtpasswdForm: basicAuthValidator != nil && opts.DisplayHtpasswdForm,
|
displayHtpasswdForm: basicAuthValidator != nil && opts.DisplayHtpasswdForm,
|
||||||
@ -396,7 +397,7 @@ func (p *OAuthProxy) GetRedirectURI(host string) string {
|
|||||||
|
|
||||||
func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) {
|
func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) {
|
||||||
if code == "" {
|
if code == "" {
|
||||||
return nil, errors.New("missing code")
|
return nil, providers.ErrMissingCode
|
||||||
}
|
}
|
||||||
redirectURI := p.GetRedirectURI(host)
|
redirectURI := p.GetRedirectURI(host)
|
||||||
s, err := p.provider.Redeem(ctx, redirectURI, code)
|
s, err := p.provider.Redeem(ctx, redirectURI, code)
|
||||||
@ -909,11 +910,15 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// set cookie, or deny
|
// set cookie, or deny
|
||||||
if p.Validator(session.Email) && p.provider.ValidateGroup(session.Email) {
|
authorized, err := p.provider.Authorize(req.Context(), session)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Error with authorization: %v", err)
|
||||||
|
}
|
||||||
|
if p.Validator(session.Email) && authorized {
|
||||||
logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session)
|
logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session)
|
||||||
err := p.SaveSession(rw, req, session)
|
err := p.SaveSession(rw, req, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Printf("Error saving session state for %s: %v", remoteAddr, err)
|
logger.Errorf("Error saving session state for %s: %v", remoteAddr, err)
|
||||||
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -967,6 +972,9 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
|
|||||||
p.SignInPage(rw, req, http.StatusForbidden)
|
p.SignInPage(rw, req, http.StatusForbidden)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case ErrAccessDenied:
|
||||||
|
p.ErrorPage(rw, http.StatusUnauthorized, "Permission Denied", "Unauthorized")
|
||||||
|
|
||||||
default:
|
default:
|
||||||
// unknown error
|
// unknown error
|
||||||
logger.Errorf("Unexpected internal error: %v", err)
|
logger.Errorf("Unexpected internal error: %v", err)
|
||||||
@ -977,7 +985,9 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getAuthenticatedSession checks whether a user is authenticated and returns a session object and nil error if so
|
// getAuthenticatedSession checks whether a user is authenticated and returns a session object and nil error if so
|
||||||
// Returns nil, ErrNeedsLogin if user needs to login.
|
// Returns:
|
||||||
|
// - `nil, ErrNeedsLogin` if user needs to login.
|
||||||
|
// - `nil, ErrAccessDenied` if the authenticated user is not authorized
|
||||||
// Set-Cookie headers may be set on the response as a side-effect of calling this method.
|
// Set-Cookie headers may be set on the response as a side-effect of calling this method.
|
||||||
func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) {
|
func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) {
|
||||||
var session *sessionsapi.SessionState
|
var session *sessionsapi.SessionState
|
||||||
@ -991,17 +1001,20 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R
|
|||||||
return nil, ErrNeedsLogin
|
return nil, ErrNeedsLogin
|
||||||
}
|
}
|
||||||
|
|
||||||
invalidEmail := session != nil && session.Email != "" && !p.Validator(session.Email)
|
invalidEmail := session.Email != "" && !p.Validator(session.Email)
|
||||||
invalidGroups := session != nil && !p.validateGroups(session.Groups)
|
authorized, err := p.provider.Authorize(req.Context(), session)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Error with authorization: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if invalidEmail || invalidGroups {
|
if invalidEmail || !authorized {
|
||||||
logger.Printf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session)
|
logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authorization via session: removing session %s", session)
|
||||||
// Invalid session, clear it
|
// Invalid session, clear it
|
||||||
err := p.ClearSessionCookie(rw, req)
|
err := p.ClearSessionCookie(rw, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Printf("Error clearing session cookie: %v", err)
|
logger.Errorf("Error clearing session cookie: %v", err)
|
||||||
}
|
}
|
||||||
return nil, ErrNeedsLogin
|
return nil, ErrAccessDenied
|
||||||
}
|
}
|
||||||
|
|
||||||
return session, nil
|
return session, nil
|
||||||
@ -1033,23 +1046,3 @@ func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) {
|
|||||||
rw.Header().Set("Content-Type", applicationJSON)
|
rw.Header().Set("Content-Type", applicationJSON)
|
||||||
rw.WriteHeader(code)
|
rw.WriteHeader(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OAuthProxy) validateGroups(groups []string) bool {
|
|
||||||
if len(p.AllowedGroups) == 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
allowedGroups := map[string]struct{}{}
|
|
||||||
|
|
||||||
for _, group := range p.AllowedGroups {
|
|
||||||
allowedGroups[group] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, group := range groups {
|
|
||||||
if _, ok := allowedGroups[group]; ok {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
@ -976,8 +976,10 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
pcTest.proxy.provider = &TestProvider{
|
pcTest.proxy.provider = &TestProvider{
|
||||||
|
ProviderData: &providers.ProviderData{},
|
||||||
ValidToken: opts.providerValidateCookieResponse,
|
ValidToken: opts.providerValidateCookieResponse,
|
||||||
}
|
}
|
||||||
|
pcTest.proxy.provider.(*TestProvider).SetAllowedGroups(pcTest.opts.AllowedGroups)
|
||||||
|
|
||||||
// Now, zero-out proxy.CookieRefresh for the cases that don't involve
|
// Now, zero-out proxy.CookieRefresh for the cases that don't involve
|
||||||
// access_token validation.
|
// access_token validation.
|
||||||
@ -1284,6 +1286,7 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
pcTest.proxy.provider = &TestProvider{
|
pcTest.proxy.provider = &TestProvider{
|
||||||
|
ProviderData: &providers.ProviderData{},
|
||||||
ValidToken: true,
|
ValidToken: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1376,6 +1379,7 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
pcTest.proxy.provider = &TestProvider{
|
pcTest.proxy.provider = &TestProvider{
|
||||||
|
ProviderData: &providers.ProviderData{},
|
||||||
ValidToken: true,
|
ValidToken: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1455,6 +1459,7 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
pcTest.proxy.provider = &TestProvider{
|
pcTest.proxy.provider = &TestProvider{
|
||||||
|
ProviderData: &providers.ProviderData{},
|
||||||
ValidToken: true,
|
ValidToken: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -233,6 +233,8 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
|
|||||||
p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs)
|
p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs)
|
||||||
p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs)
|
p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs)
|
||||||
|
|
||||||
|
p.SetAllowedGroups(o.AllowedGroups)
|
||||||
|
|
||||||
provider := providers.New(o.ProviderType, p)
|
provider := providers.New(o.ProviderType, p)
|
||||||
if provider == nil {
|
if provider == nil {
|
||||||
msgs = append(msgs, fmt.Sprintf("invalid setting: provider '%s' is not available", o.ProviderType))
|
msgs = append(msgs, fmt.Sprintf("invalid setting: provider '%s' is not available", o.ProviderType))
|
||||||
@ -255,7 +257,13 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
msgs = append(msgs, "invalid Google credentials file: "+o.GoogleServiceAccountJSON)
|
msgs = append(msgs, "invalid Google credentials file: "+o.GoogleServiceAccountJSON)
|
||||||
} else {
|
} else {
|
||||||
p.SetGroupRestriction(o.GoogleGroups, o.GoogleAdminEmail, file)
|
groups := o.AllowedGroups
|
||||||
|
// Backwards compatibility with `--google-group` option
|
||||||
|
if len(o.GoogleGroups) > 0 {
|
||||||
|
groups = o.GoogleGroups
|
||||||
|
p.SetAllowedGroups(groups)
|
||||||
|
}
|
||||||
|
p.SetGroupRestriction(groups, o.GoogleAdminEmail, file)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case *providers.BitbucketProvider:
|
case *providers.BitbucketProvider:
|
||||||
|
@ -108,14 +108,13 @@ func overrideTenantURL(current, defaultURL *url.URL, tenant, path string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||||
func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
|
func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
|
||||||
if code == "" {
|
if code == "" {
|
||||||
err = errors.New("missing code")
|
return nil, ErrMissingCode
|
||||||
return
|
|
||||||
}
|
}
|
||||||
clientSecret, err := p.GetClientSecret()
|
clientSecret, err := p.GetClientSecret()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
params := url.Values{}
|
params := url.Values{}
|
||||||
@ -149,15 +148,14 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s
|
|||||||
|
|
||||||
created := time.Now()
|
created := time.Now()
|
||||||
expires := time.Unix(jsonResponse.ExpiresOn, 0)
|
expires := time.Unix(jsonResponse.ExpiresOn, 0)
|
||||||
s = &sessions.SessionState{
|
|
||||||
|
return &sessions.SessionState{
|
||||||
AccessToken: jsonResponse.AccessToken,
|
AccessToken: jsonResponse.AccessToken,
|
||||||
IDToken: jsonResponse.IDToken,
|
IDToken: jsonResponse.IDToken,
|
||||||
CreatedAt: &created,
|
CreatedAt: &created,
|
||||||
ExpiresOn: &expires,
|
ExpiresOn: &expires,
|
||||||
RefreshToken: jsonResponse.RefreshToken,
|
RefreshToken: jsonResponse.RefreshToken,
|
||||||
}
|
}, nil
|
||||||
return
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||||
|
@ -25,10 +25,16 @@ import (
|
|||||||
// GoogleProvider represents an Google based Identity Provider
|
// GoogleProvider represents an Google based Identity Provider
|
||||||
type GoogleProvider struct {
|
type GoogleProvider struct {
|
||||||
*ProviderData
|
*ProviderData
|
||||||
|
|
||||||
RedeemRefreshURL *url.URL
|
RedeemRefreshURL *url.URL
|
||||||
// GroupValidator is a function that determines if the passed email is in
|
|
||||||
// the configured Google group.
|
// groupValidator is a function that determines if the user in the passed
|
||||||
GroupValidator func(string) bool
|
// session is a member of any of the configured Google groups.
|
||||||
|
//
|
||||||
|
// This hits the Google API for each group, so it is called on Redeem &
|
||||||
|
// Refresh. `Authorize` uses the results of this saved in `session.Groups`
|
||||||
|
// Since it is called on every request.
|
||||||
|
groupValidator func(*sessions.SessionState) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Provider = (*GoogleProvider)(nil)
|
var _ Provider = (*GoogleProvider)(nil)
|
||||||
@ -84,9 +90,9 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider {
|
|||||||
})
|
})
|
||||||
return &GoogleProvider{
|
return &GoogleProvider{
|
||||||
ProviderData: p,
|
ProviderData: p,
|
||||||
// Set a default GroupValidator to just always return valid (true), it will
|
// Set a default groupValidator to just always return valid (true), it will
|
||||||
// be overwritten if we configured a Google group restriction.
|
// be overwritten if we configured a Google group restriction.
|
||||||
GroupValidator: func(email string) bool {
|
groupValidator: func(*sessions.SessionState) bool {
|
||||||
return true
|
return true
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -118,14 +124,13 @@ func claimsFromIDToken(idToken string) (*claims, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||||
func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
|
func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
|
||||||
if code == "" {
|
if code == "" {
|
||||||
err = errors.New("missing code")
|
return nil, ErrMissingCode
|
||||||
return
|
|
||||||
}
|
}
|
||||||
clientSecret, err := p.GetClientSecret()
|
clientSecret, err := p.GetClientSecret()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
params := url.Values{}
|
params := url.Values{}
|
||||||
@ -155,12 +160,13 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
|
|||||||
|
|
||||||
c, err := claimsFromIDToken(jsonResponse.IDToken)
|
c, err := claimsFromIDToken(jsonResponse.IDToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
created := time.Now()
|
created := time.Now()
|
||||||
expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
|
expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
|
||||||
s = &sessions.SessionState{
|
|
||||||
|
return &sessions.SessionState{
|
||||||
AccessToken: jsonResponse.AccessToken,
|
AccessToken: jsonResponse.AccessToken,
|
||||||
IDToken: jsonResponse.IDToken,
|
IDToken: jsonResponse.IDToken,
|
||||||
CreatedAt: &created,
|
CreatedAt: &created,
|
||||||
@ -168,18 +174,40 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
|
|||||||
RefreshToken: jsonResponse.RefreshToken,
|
RefreshToken: jsonResponse.RefreshToken,
|
||||||
Email: c.Email,
|
Email: c.Email,
|
||||||
User: c.Subject,
|
User: c.Subject,
|
||||||
}
|
}, nil
|
||||||
return
|
}
|
||||||
|
|
||||||
|
// EnrichSessionState checks the listed Google Groups configured and adds any
|
||||||
|
// that the user is a member of to session.Groups.
|
||||||
|
func (p *GoogleProvider) EnrichSessionState(ctx context.Context, s *sessions.SessionState) error {
|
||||||
|
// TODO (@NickMeves) - Move to pure EnrichSessionState logic and stop
|
||||||
|
// reusing legacy `groupValidator`.
|
||||||
|
//
|
||||||
|
// This is called here to get the validator to do the `session.Groups`
|
||||||
|
// populating logic.
|
||||||
|
p.groupValidator(s)
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetGroupRestriction configures the GoogleProvider to restrict access to the
|
// SetGroupRestriction configures the GoogleProvider to restrict access to the
|
||||||
// specified group(s). AdminEmail has to be an administrative email on the domain that is
|
// specified group(s). AdminEmail has to be an administrative email on the domain that is
|
||||||
// checked. CredentialsFile is the path to a json file containing a Google service
|
// checked. CredentialsFile is the path to a json file containing a Google service
|
||||||
// account credentials.
|
// account credentials.
|
||||||
|
//
|
||||||
|
// TODO (@NickMeves) - Unit Test this OR refactor away from groupValidator func
|
||||||
func (p *GoogleProvider) SetGroupRestriction(groups []string, adminEmail string, credentialsReader io.Reader) {
|
func (p *GoogleProvider) SetGroupRestriction(groups []string, adminEmail string, credentialsReader io.Reader) {
|
||||||
adminService := getAdminService(adminEmail, credentialsReader)
|
adminService := getAdminService(adminEmail, credentialsReader)
|
||||||
p.GroupValidator = func(email string) bool {
|
p.groupValidator = func(s *sessions.SessionState) bool {
|
||||||
return userInGroup(adminService, groups, email)
|
// Reset our saved Groups in case membership changed
|
||||||
|
// This is used by `Authorize` on every request
|
||||||
|
s.Groups = make([]string, 0, len(groups))
|
||||||
|
for _, group := range groups {
|
||||||
|
if userInGroup(adminService, group, s.Email) {
|
||||||
|
s.Groups = append(s.Groups, group)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(s.Groups) > 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -203,12 +231,14 @@ func getAdminService(adminEmail string, credentialsReader io.Reader) *admin.Serv
|
|||||||
return adminService
|
return adminService
|
||||||
}
|
}
|
||||||
|
|
||||||
func userInGroup(service *admin.Service, groups []string, email string) bool {
|
func userInGroup(service *admin.Service, group string, email string) bool {
|
||||||
for _, group := range groups {
|
|
||||||
// Use the HasMember API to checking for the user's presence in each group or nested subgroups
|
// Use the HasMember API to checking for the user's presence in each group or nested subgroups
|
||||||
req := service.Members.HasMember(group, email)
|
req := service.Members.HasMember(group, email)
|
||||||
r, err := req.Do()
|
r, err := req.Do()
|
||||||
if err != nil {
|
if err == nil {
|
||||||
|
return r.IsMember
|
||||||
|
}
|
||||||
|
|
||||||
gerr, ok := err.(*googleapi.Error)
|
gerr, ok := err.(*googleapi.Error)
|
||||||
switch {
|
switch {
|
||||||
case ok && gerr.Code == 404:
|
case ok && gerr.Code == 404:
|
||||||
@ -220,10 +250,9 @@ func userInGroup(service *admin.Service, groups []string, email string) bool {
|
|||||||
// from the HasMember API. In that case, attempt to query the member object directly from the group.
|
// from the HasMember API. In that case, attempt to query the member object directly from the group.
|
||||||
req := service.Members.Get(group, email)
|
req := service.Members.Get(group, email)
|
||||||
r, err := req.Do()
|
r, err := req.Do()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("error using get API to check member %s of google group %s: user not in the group", email, group)
|
logger.Errorf("error using get API to check member %s of google group %s: user not in the group", email, group)
|
||||||
continue
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the non-domain user is found within the group, still verify that they are "ACTIVE".
|
// If the non-domain user is found within the group, still verify that they are "ACTIVE".
|
||||||
@ -234,21 +263,9 @@ func userInGroup(service *admin.Service, groups []string, email string) bool {
|
|||||||
default:
|
default:
|
||||||
logger.Errorf("error checking group membership: %v", err)
|
logger.Errorf("error checking group membership: %v", err)
|
||||||
}
|
}
|
||||||
continue
|
|
||||||
}
|
|
||||||
if r.IsMember {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateGroup validates that the provided email exists in the configured Google
|
|
||||||
// group(s).
|
|
||||||
func (p *GoogleProvider) ValidateGroup(email string) bool {
|
|
||||||
return p.GroupValidator(email)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||||
// RefreshToken to fetch a new ID token if required
|
// RefreshToken to fetch a new ID token if required
|
||||||
func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||||
@ -261,8 +278,11 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
|
|||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO (@NickMeves) - Align Group authorization needs with other providers'
|
||||||
|
// behavior in the `RefreshSession` case.
|
||||||
|
//
|
||||||
// re-check that the user is in the proper google group(s)
|
// re-check that the user is in the proper google group(s)
|
||||||
if !p.ValidateGroup(s.Email) {
|
if !p.groupValidator(s) {
|
||||||
return false, fmt.Errorf("%s is no longer in the group(s)", s.Email)
|
return false, fmt.Errorf("%s is no longer in the group(s)", s.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
admin "google.golang.org/api/admin/directory/v1"
|
admin "google.golang.org/api/admin/directory/v1"
|
||||||
@ -109,21 +110,50 @@ func TestGoogleProviderGetEmailAddress(t *testing.T) {
|
|||||||
assert.Equal(t, "refresh12345", session.RefreshToken)
|
assert.Equal(t, "refresh12345", session.RefreshToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGoogleProviderValidateGroup(t *testing.T) {
|
func TestGoogleProviderGroupValidator(t *testing.T) {
|
||||||
p := newGoogleProvider()
|
const sessionEmail = "michael.bland@gsa.gov"
|
||||||
p.GroupValidator = func(email string) bool {
|
|
||||||
return email == "michael.bland@gsa.gov"
|
|
||||||
}
|
|
||||||
assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov"))
|
|
||||||
p.GroupValidator = func(email string) bool {
|
|
||||||
return email != "michael.bland@gsa.gov"
|
|
||||||
}
|
|
||||||
assert.Equal(t, false, p.ValidateGroup("michael.bland@gsa.gov"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGoogleProviderWithoutValidateGroup(t *testing.T) {
|
testCases := map[string]struct {
|
||||||
|
session *sessions.SessionState
|
||||||
|
validatorFunc func(*sessions.SessionState) bool
|
||||||
|
expectedAuthZ bool
|
||||||
|
}{
|
||||||
|
"Email is authorized with groupValidator": {
|
||||||
|
session: &sessions.SessionState{
|
||||||
|
Email: sessionEmail,
|
||||||
|
},
|
||||||
|
validatorFunc: func(s *sessions.SessionState) bool {
|
||||||
|
return s.Email == sessionEmail
|
||||||
|
},
|
||||||
|
expectedAuthZ: true,
|
||||||
|
},
|
||||||
|
"Email is denied with groupValidator": {
|
||||||
|
session: &sessions.SessionState{
|
||||||
|
Email: sessionEmail,
|
||||||
|
},
|
||||||
|
validatorFunc: func(s *sessions.SessionState) bool {
|
||||||
|
return s.Email != sessionEmail
|
||||||
|
},
|
||||||
|
expectedAuthZ: false,
|
||||||
|
},
|
||||||
|
"Default does no authorization checks": {
|
||||||
|
session: &sessions.SessionState{
|
||||||
|
Email: sessionEmail,
|
||||||
|
},
|
||||||
|
validatorFunc: nil,
|
||||||
|
expectedAuthZ: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
g := NewWithT(t)
|
||||||
p := newGoogleProvider()
|
p := newGoogleProvider()
|
||||||
assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov"))
|
if tc.validatorFunc != nil {
|
||||||
|
p.groupValidator = tc.validatorFunc
|
||||||
|
}
|
||||||
|
g.Expect(p.groupValidator(tc.session)).To(Equal(tc.expectedAuthZ))
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -196,7 +226,7 @@ func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGoogleProviderUserInGroup(t *testing.T) {
|
func TestGoogleProvider_userInGroup(t *testing.T) {
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.URL.Path == "/groups/group@example.com/hasMember/member-in-domain@example.com" {
|
if r.URL.Path == "/groups/group@example.com/hasMember/member-in-domain@example.com" {
|
||||||
fmt.Fprintln(w, `{"isMember": true}`)
|
fmt.Fprintln(w, `{"isMember": true}`)
|
||||||
@ -233,18 +263,19 @@ func TestGoogleProviderUserInGroup(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
service, err := admin.NewService(ctx, option.WithHTTPClient(client))
|
service, err := admin.NewService(ctx, option.WithHTTPClient(client))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
service.BasePath = ts.URL
|
service.BasePath = ts.URL
|
||||||
assert.Equal(t, nil, err)
|
|
||||||
|
|
||||||
result := userInGroup(service, []string{"group@example.com"}, "member-in-domain@example.com")
|
result := userInGroup(service, "group@example.com", "member-in-domain@example.com")
|
||||||
assert.True(t, result)
|
assert.True(t, result)
|
||||||
|
|
||||||
result = userInGroup(service, []string{"group@example.com"}, "member-out-of-domain@otherexample.com")
|
result = userInGroup(service, "group@example.com", "member-out-of-domain@otherexample.com")
|
||||||
assert.True(t, result)
|
assert.True(t, result)
|
||||||
|
|
||||||
result = userInGroup(service, []string{"group@example.com"}, "non-member-in-domain@example.com")
|
result = userInGroup(service, "group@example.com", "non-member-in-domain@example.com")
|
||||||
assert.False(t, result)
|
assert.False(t, result)
|
||||||
|
|
||||||
result = userInGroup(service, []string{"group@example.com"}, "non-member-out-of-domain@otherexample.com")
|
result = userInGroup(service, "group@example.com", "non-member-out-of-domain@otherexample.com")
|
||||||
assert.False(t, result)
|
assert.False(t, result)
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -153,10 +152,9 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||||
func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
|
func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
|
||||||
if code == "" {
|
if code == "" {
|
||||||
err = errors.New("missing code")
|
return nil, ErrMissingCode
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
claims := &jwt.StandardClaims{
|
claims := &jwt.StandardClaims{
|
||||||
@ -169,7 +167,7 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
|
|||||||
token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), claims)
|
token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), claims)
|
||||||
ss, err := token.SignedString(p.JWTKey)
|
ss, err := token.SignedString(p.JWTKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
params := url.Values{}
|
params := url.Values{}
|
||||||
@ -199,28 +197,27 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
|
|||||||
// check nonce here
|
// check nonce here
|
||||||
err = checkNonce(jsonResponse.IDToken, p)
|
err = checkNonce(jsonResponse.IDToken, p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the email address
|
// Get the email address
|
||||||
var email string
|
var email string
|
||||||
email, err = emailFromUserInfo(ctx, jsonResponse.AccessToken, p.ProfileURL.String())
|
email, err = emailFromUserInfo(ctx, jsonResponse.AccessToken, p.ProfileURL.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
created := time.Now()
|
created := time.Now()
|
||||||
expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
|
expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
|
||||||
|
|
||||||
// Store the data that we found in the session state
|
// Store the data that we found in the session state
|
||||||
s = &sessions.SessionState{
|
return &sessions.SessionState{
|
||||||
AccessToken: jsonResponse.AccessToken,
|
AccessToken: jsonResponse.AccessToken,
|
||||||
IDToken: jsonResponse.IDToken,
|
IDToken: jsonResponse.IDToken,
|
||||||
CreatedAt: &created,
|
CreatedAt: &created,
|
||||||
ExpiresOn: &expires,
|
ExpiresOn: &expires,
|
||||||
Email: email,
|
Email: email,
|
||||||
}
|
}, nil
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLoginURL overrides GetLoginURL to add login.gov parameters
|
// GetLoginURL overrides GetLoginURL to add login.gov parameters
|
||||||
|
@ -26,6 +26,10 @@ type ProviderData struct {
|
|||||||
ClientSecretFile string
|
ClientSecretFile string
|
||||||
Scope string
|
Scope string
|
||||||
Prompt string
|
Prompt string
|
||||||
|
|
||||||
|
// Universal Group authorization data structure
|
||||||
|
// any provider can set to consume
|
||||||
|
AllowedGroups map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Data returns the ProviderData
|
// Data returns the ProviderData
|
||||||
@ -45,6 +49,15 @@ func (p *ProviderData) GetClientSecret() (clientSecret string, err error) {
|
|||||||
return string(fileClientSecret), nil
|
return string(fileClientSecret), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAllowedGroups organizes a group list into the AllowedGroups map
|
||||||
|
// to be consumed by Authorize implementations
|
||||||
|
func (p *ProviderData) SetAllowedGroups(groups []string) {
|
||||||
|
p.AllowedGroups = make(map[string]struct{}, len(groups))
|
||||||
|
for _, group := range groups {
|
||||||
|
p.AllowedGroups[group] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type providerDefaults struct {
|
type providerDefaults struct {
|
||||||
name string
|
name string
|
||||||
loginURL *url.URL
|
loginURL *url.URL
|
||||||
|
@ -19,18 +19,21 @@ var (
|
|||||||
// implementation method that doesn't have sensible defaults
|
// implementation method that doesn't have sensible defaults
|
||||||
ErrNotImplemented = errors.New("not implemented")
|
ErrNotImplemented = errors.New("not implemented")
|
||||||
|
|
||||||
|
// ErrMissingCode is returned when a Redeem method is called with an empty
|
||||||
|
// code
|
||||||
|
ErrMissingCode = errors.New("missing code")
|
||||||
|
|
||||||
_ Provider = (*ProviderData)(nil)
|
_ Provider = (*ProviderData)(nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
// Redeem provides a default implementation of the OAuth2 token redemption process
|
// Redeem provides a default implementation of the OAuth2 token redemption process
|
||||||
func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
|
func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
|
||||||
if code == "" {
|
if code == "" {
|
||||||
err = errors.New("missing code")
|
return nil, ErrMissingCode
|
||||||
return
|
|
||||||
}
|
}
|
||||||
clientSecret, err := p.GetClientSecret()
|
clientSecret, err := p.GetClientSecret()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
params := url.Values{}
|
params := url.Values{}
|
||||||
@ -59,24 +62,21 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
|
|||||||
}
|
}
|
||||||
err = result.UnmarshalInto(&jsonResponse)
|
err = result.UnmarshalInto(&jsonResponse)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
s = &sessions.SessionState{
|
return &sessions.SessionState{
|
||||||
AccessToken: jsonResponse.AccessToken,
|
AccessToken: jsonResponse.AccessToken,
|
||||||
}
|
}, nil
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var v url.Values
|
values, err := url.ParseQuery(string(result.Body()))
|
||||||
v, err = url.ParseQuery(string(result.Body()))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
if a := v.Get("access_token"); a != "" {
|
if token := values.Get("access_token"); token != "" {
|
||||||
created := time.Now()
|
created := time.Now()
|
||||||
s = &sessions.SessionState{AccessToken: a, CreatedAt: &created}
|
return &sessions.SessionState{AccessToken: token, CreatedAt: &created}, nil
|
||||||
} else {
|
|
||||||
err = fmt.Errorf("no access token found %s", result.Body())
|
|
||||||
}
|
}
|
||||||
return
|
|
||||||
|
return nil, fmt.Errorf("no access token found %s", result.Body())
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLoginURL with typical oauth parameters
|
// GetLoginURL with typical oauth parameters
|
||||||
@ -92,18 +92,28 @@ func (p *ProviderData) GetEmailAddress(_ context.Context, _ *sessions.SessionSta
|
|||||||
return "", ErrNotImplemented
|
return "", ErrNotImplemented
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateGroup validates that the provided email exists in the configured provider
|
|
||||||
// email group(s).
|
|
||||||
func (p *ProviderData) ValidateGroup(_ string) bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// EnrichSessionState is called after Redeem to allow providers to enrich session fields
|
// EnrichSessionState is called after Redeem to allow providers to enrich session fields
|
||||||
// such as User, Email, Groups with provider specific API calls.
|
// such as User, Email, Groups with provider specific API calls.
|
||||||
func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.SessionState) error {
|
func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.SessionState) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Authorize performs global authorization on an authenticated session.
|
||||||
|
// This is not used for fine-grained per route authorization rules.
|
||||||
|
func (p *ProviderData) Authorize(_ context.Context, s *sessions.SessionState) (bool, error) {
|
||||||
|
if len(p.AllowedGroups) == 0 {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, group := range s.Groups {
|
||||||
|
if _, ok := p.AllowedGroups[group]; ok {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ValidateSessionState validates the AccessToken
|
// ValidateSessionState validates the AccessToken
|
||||||
func (p *ProviderData) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool {
|
func (p *ProviderData) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool {
|
||||||
return validateToken(ctx, p, s.AccessToken, nil)
|
return validateToken(ctx, p, s.AccessToken, nil)
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -53,3 +54,53 @@ func TestEnrichSessionState(t *testing.T) {
|
|||||||
s := &sessions.SessionState{}
|
s := &sessions.SessionState{}
|
||||||
assert.NoError(t, p.EnrichSessionState(context.Background(), s))
|
assert.NoError(t, p.EnrichSessionState(context.Background(), s))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProviderDataAuthorize(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
allowedGroups []string
|
||||||
|
groups []string
|
||||||
|
expectedAuthZ bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "NoAllowedGroups",
|
||||||
|
allowedGroups: []string{},
|
||||||
|
groups: []string{},
|
||||||
|
expectedAuthZ: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NoAllowedGroupsUserHasGroups",
|
||||||
|
allowedGroups: []string{},
|
||||||
|
groups: []string{"foo", "bar"},
|
||||||
|
expectedAuthZ: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UserInAllowedGroup",
|
||||||
|
allowedGroups: []string{"foo"},
|
||||||
|
groups: []string{"foo", "bar"},
|
||||||
|
expectedAuthZ: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UserNotInAllowedGroup",
|
||||||
|
allowedGroups: []string{"bar"},
|
||||||
|
groups: []string{"baz", "foo"},
|
||||||
|
expectedAuthZ: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
g := NewWithT(t)
|
||||||
|
|
||||||
|
session := &sessions.SessionState{
|
||||||
|
Groups: tc.groups,
|
||||||
|
}
|
||||||
|
p := &ProviderData{}
|
||||||
|
p.SetAllowedGroups(tc.allowedGroups)
|
||||||
|
|
||||||
|
authorized, err := p.Authorize(context.Background(), session)
|
||||||
|
g.Expect(err).ToNot(HaveOccurred())
|
||||||
|
g.Expect(authorized).To(Equal(tc.expectedAuthZ))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -13,8 +13,8 @@ type Provider interface {
|
|||||||
// DEPRECATED: Migrate to EnrichSessionState
|
// DEPRECATED: Migrate to EnrichSessionState
|
||||||
GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error)
|
GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error)
|
||||||
Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error)
|
Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error)
|
||||||
ValidateGroup(string) bool
|
|
||||||
EnrichSessionState(ctx context.Context, s *sessions.SessionState) error
|
EnrichSessionState(ctx context.Context, s *sessions.SessionState) error
|
||||||
|
Authorize(ctx context.Context, s *sessions.SessionState) (bool, error)
|
||||||
ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool
|
ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool
|
||||||
GetLoginURL(redirectURI, finalRedirect string) string
|
GetLoginURL(redirectURI, finalRedirect string) string
|
||||||
RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error)
|
RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user