1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-03-19 21:27:58 +02:00

Merge pull request #797 from grnhse/refactor-provider-authz

Centralize Provider authorization interface method
This commit is contained in:
Joel Speed 2020-11-12 19:38:55 +00:00 committed by GitHub
commit c377466411
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 285 additions and 150 deletions

View File

@ -6,6 +6,11 @@
- [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Existing sessions from v6.0.0 or earlier are no longer valid. They will trigger a reauthentication.
- [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) `skip-auth-strip-headers` now applies to all requests, not just those where authentication would be skipped.
- [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) The behavior of the Google provider Groups restriction changes with this
- Either `--google-group` or the new `--allowed-group` will work for Google now (`--google-group` will be used if both are set)
- Group membership lists will be passed to the backend with the `X-Forwarded-Groups` header
- If you change the list of allowed groups, existing sessions that now don't have a valid group will be logged out immediately.
- Previously, group membership was only checked on session creation and refresh.
- [#789](https://github.com/oauth2-proxy/oauth2-proxy/pull/789) `--skip-auth-route` is (almost) backwards compatible with `--skip-auth-regex`
- We are marking `--skip-auth-regex` as DEPRECATED and will remove it in the next major version.
- If your regex contains an `=` and you want it for all methods, you will need to add a leading `=` (this is the area where `--skip-auth-regex` doesn't port perfectly)
@ -18,6 +23,9 @@
## Breaking Changes
- [#911](https://github.com/oauth2-proxy/oauth2-rpoxy/pull/911) Specifying a non-existent provider will cause OAuth2-Proxy to fail on startup instead of defaulting to "google".
- [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) Security changes to Google provider group authorization flow
- If you change the list of allowed groups, existing sessions that now don't have a valid group will be logged out immediately.
- Previously, group membership was only checked on session creation and refresh.
- [#722](https://github.com/oauth2-proxy/oauth2-proxy/pull/722) When a Redis session store is configured, OAuth2-Proxy will fail to start up unless connection and health checks to Redis pass
- [#800](https://github.com/oauth2-proxy/oauth2-proxy/pull/800) Fix import path for v7. The import path has changed to support the go get installation.
- You can now `go get github.com/oauth2-proxy/oauth2-proxy/v7` to get the latest `v7` version of OAuth2 Proxy
@ -40,6 +48,7 @@
- [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Remove v5 legacy sessions support (@NickMeves)
- [#904](https://github.com/oauth2-proxy/oauth2-proxy/pull/904) Set `skip-auth-strip-headers` to `true` by default (@NickMeves)
- [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) Integrate new header injectors into project (@JoelSpeed)
- [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) Create universal Authorization behavior across providers (@NickMeves)
- [#898](https://github.com/oauth2-proxy/oauth2-proxy/pull/898) Migrate documentation to Docusaurus (@JoelSpeed)
- [#754](https://github.com/oauth2-proxy/oauth2-proxy/pull/754) Azure token refresh (@codablock)
- [#825](https://github.com/oauth2-proxy/oauth2-proxy/pull/825) Fix code coverage reporting on GitHub actions(@JoelSpeed)

View File

@ -42,6 +42,9 @@ var (
// ErrNeedsLogin means the user should be redirected to the login page
ErrNeedsLogin = errors.New("redirect to login page")
// ErrAccessDenied means the user should receive a 401 Unauthorized response
ErrAccessDenied = errors.New("access denied")
// Used to check final redirects are not susceptible to open redirects.
// Matches //, /\ and both of these with whitespace in between (eg / / or / \).
invalidRedirectRegex = regexp.MustCompile(`[/\\](?:[\s\v]*|\.{1,2})[/\\]`)
@ -105,7 +108,6 @@ type OAuthProxy struct {
trustedIPs *ip.NetSet
Banner string
Footer string
AllowedGroups []string
sessionChain alice.Chain
headersChain alice.Chain
@ -219,7 +221,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
Banner: opts.Banner,
Footer: opts.Footer,
SignInMessage: buildSignInMessage(opts),
AllowedGroups: opts.AllowedGroups,
basicAuthValidator: basicAuthValidator,
displayHtpasswdForm: basicAuthValidator != nil && opts.DisplayHtpasswdForm,
@ -396,7 +397,7 @@ func (p *OAuthProxy) GetRedirectURI(host string) string {
func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) {
if code == "" {
return nil, errors.New("missing code")
return nil, providers.ErrMissingCode
}
redirectURI := p.GetRedirectURI(host)
s, err := p.provider.Redeem(ctx, redirectURI, code)
@ -909,11 +910,15 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
}
// set cookie, or deny
if p.Validator(session.Email) && p.provider.ValidateGroup(session.Email) {
authorized, err := p.provider.Authorize(req.Context(), session)
if err != nil {
logger.Errorf("Error with authorization: %v", err)
}
if p.Validator(session.Email) && authorized {
logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session)
err := p.SaveSession(rw, req, session)
if err != nil {
logger.Printf("Error saving session state for %s: %v", remoteAddr, err)
logger.Errorf("Error saving session state for %s: %v", remoteAddr, err)
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
return
}
@ -967,6 +972,9 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
p.SignInPage(rw, req, http.StatusForbidden)
}
case ErrAccessDenied:
p.ErrorPage(rw, http.StatusUnauthorized, "Permission Denied", "Unauthorized")
default:
// unknown error
logger.Errorf("Unexpected internal error: %v", err)
@ -977,7 +985,9 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
}
// getAuthenticatedSession checks whether a user is authenticated and returns a session object and nil error if so
// Returns nil, ErrNeedsLogin if user needs to login.
// Returns:
// - `nil, ErrNeedsLogin` if user needs to login.
// - `nil, ErrAccessDenied` if the authenticated user is not authorized
// Set-Cookie headers may be set on the response as a side-effect of calling this method.
func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) {
var session *sessionsapi.SessionState
@ -991,17 +1001,20 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R
return nil, ErrNeedsLogin
}
invalidEmail := session != nil && session.Email != "" && !p.Validator(session.Email)
invalidGroups := session != nil && !p.validateGroups(session.Groups)
invalidEmail := session.Email != "" && !p.Validator(session.Email)
authorized, err := p.provider.Authorize(req.Context(), session)
if err != nil {
logger.Errorf("Error with authorization: %v", err)
}
if invalidEmail || invalidGroups {
logger.Printf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session)
if invalidEmail || !authorized {
logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authorization via session: removing session %s", session)
// Invalid session, clear it
err := p.ClearSessionCookie(rw, req)
if err != nil {
logger.Printf("Error clearing session cookie: %v", err)
logger.Errorf("Error clearing session cookie: %v", err)
}
return nil, ErrNeedsLogin
return nil, ErrAccessDenied
}
return session, nil
@ -1033,23 +1046,3 @@ func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) {
rw.Header().Set("Content-Type", applicationJSON)
rw.WriteHeader(code)
}
func (p *OAuthProxy) validateGroups(groups []string) bool {
if len(p.AllowedGroups) == 0 {
return true
}
allowedGroups := map[string]struct{}{}
for _, group := range p.AllowedGroups {
allowedGroups[group] = struct{}{}
}
for _, group := range groups {
if _, ok := allowedGroups[group]; ok {
return true
}
}
return false
}

View File

@ -976,8 +976,10 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi
return nil, err
}
pcTest.proxy.provider = &TestProvider{
ValidToken: opts.providerValidateCookieResponse,
ProviderData: &providers.ProviderData{},
ValidToken: opts.providerValidateCookieResponse,
}
pcTest.proxy.provider.(*TestProvider).SetAllowedGroups(pcTest.opts.AllowedGroups)
// Now, zero-out proxy.CookieRefresh for the cases that don't involve
// access_token validation.
@ -1284,7 +1286,8 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
t.Fatal(err)
}
pcTest.proxy.provider = &TestProvider{
ValidToken: true,
ProviderData: &providers.ProviderData{},
ValidToken: true,
}
pcTest.validateUser = true
@ -1376,7 +1379,8 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) {
t.Fatal(err)
}
pcTest.proxy.provider = &TestProvider{
ValidToken: true,
ProviderData: &providers.ProviderData{},
ValidToken: true,
}
pcTest.validateUser = true
@ -1455,7 +1459,8 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) {
t.Fatal(err)
}
pcTest.proxy.provider = &TestProvider{
ValidToken: true,
ProviderData: &providers.ProviderData{},
ValidToken: true,
}
pcTest.validateUser = true

View File

@ -233,6 +233,8 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs)
p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs)
p.SetAllowedGroups(o.AllowedGroups)
provider := providers.New(o.ProviderType, p)
if provider == nil {
msgs = append(msgs, fmt.Sprintf("invalid setting: provider '%s' is not available", o.ProviderType))
@ -255,7 +257,13 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
if err != nil {
msgs = append(msgs, "invalid Google credentials file: "+o.GoogleServiceAccountJSON)
} else {
p.SetGroupRestriction(o.GoogleGroups, o.GoogleAdminEmail, file)
groups := o.AllowedGroups
// Backwards compatibility with `--google-group` option
if len(o.GoogleGroups) > 0 {
groups = o.GoogleGroups
p.SetAllowedGroups(groups)
}
p.SetGroupRestriction(groups, o.GoogleAdminEmail, file)
}
}
case *providers.BitbucketProvider:

View File

@ -108,14 +108,13 @@ func overrideTenantURL(current, defaultURL *url.URL, tenant, path string) {
}
// Redeem exchanges the OAuth2 authentication token for an ID token
func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
if code == "" {
err = errors.New("missing code")
return
return nil, ErrMissingCode
}
clientSecret, err := p.GetClientSecret()
if err != nil {
return
return nil, err
}
params := url.Values{}
@ -149,15 +148,14 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s
created := time.Now()
expires := time.Unix(jsonResponse.ExpiresOn, 0)
s = &sessions.SessionState{
return &sessions.SessionState{
AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken,
CreatedAt: &created,
ExpiresOn: &expires,
RefreshToken: jsonResponse.RefreshToken,
}
return
}, nil
}
// RefreshSessionIfNeeded checks if the session has expired and uses the

View File

@ -25,10 +25,16 @@ import (
// GoogleProvider represents an Google based Identity Provider
type GoogleProvider struct {
*ProviderData
RedeemRefreshURL *url.URL
// GroupValidator is a function that determines if the passed email is in
// the configured Google group.
GroupValidator func(string) bool
// groupValidator is a function that determines if the user in the passed
// session is a member of any of the configured Google groups.
//
// This hits the Google API for each group, so it is called on Redeem &
// Refresh. `Authorize` uses the results of this saved in `session.Groups`
// Since it is called on every request.
groupValidator func(*sessions.SessionState) bool
}
var _ Provider = (*GoogleProvider)(nil)
@ -84,9 +90,9 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider {
})
return &GoogleProvider{
ProviderData: p,
// Set a default GroupValidator to just always return valid (true), it will
// Set a default groupValidator to just always return valid (true), it will
// be overwritten if we configured a Google group restriction.
GroupValidator: func(email string) bool {
groupValidator: func(*sessions.SessionState) bool {
return true
},
}
@ -118,14 +124,13 @@ func claimsFromIDToken(idToken string) (*claims, error) {
}
// Redeem exchanges the OAuth2 authentication token for an ID token
func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
if code == "" {
err = errors.New("missing code")
return
return nil, ErrMissingCode
}
clientSecret, err := p.GetClientSecret()
if err != nil {
return
return nil, err
}
params := url.Values{}
@ -155,12 +160,13 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
c, err := claimsFromIDToken(jsonResponse.IDToken)
if err != nil {
return
return nil, err
}
created := time.Now()
expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
s = &sessions.SessionState{
return &sessions.SessionState{
AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken,
CreatedAt: &created,
@ -168,18 +174,40 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
RefreshToken: jsonResponse.RefreshToken,
Email: c.Email,
User: c.Subject,
}
return
}, nil
}
// EnrichSessionState checks the listed Google Groups configured and adds any
// that the user is a member of to session.Groups.
func (p *GoogleProvider) EnrichSessionState(ctx context.Context, s *sessions.SessionState) error {
// TODO (@NickMeves) - Move to pure EnrichSessionState logic and stop
// reusing legacy `groupValidator`.
//
// This is called here to get the validator to do the `session.Groups`
// populating logic.
p.groupValidator(s)
return nil
}
// SetGroupRestriction configures the GoogleProvider to restrict access to the
// specified group(s). AdminEmail has to be an administrative email on the domain that is
// checked. CredentialsFile is the path to a json file containing a Google service
// account credentials.
//
// TODO (@NickMeves) - Unit Test this OR refactor away from groupValidator func
func (p *GoogleProvider) SetGroupRestriction(groups []string, adminEmail string, credentialsReader io.Reader) {
adminService := getAdminService(adminEmail, credentialsReader)
p.GroupValidator = func(email string) bool {
return userInGroup(adminService, groups, email)
p.groupValidator = func(s *sessions.SessionState) bool {
// Reset our saved Groups in case membership changed
// This is used by `Authorize` on every request
s.Groups = make([]string, 0, len(groups))
for _, group := range groups {
if userInGroup(adminService, group, s.Email) {
s.Groups = append(s.Groups, group)
}
}
return len(s.Groups) > 0
}
}
@ -203,52 +231,41 @@ func getAdminService(adminEmail string, credentialsReader io.Reader) *admin.Serv
return adminService
}
func userInGroup(service *admin.Service, groups []string, email string) bool {
for _, group := range groups {
// Use the HasMember API to checking for the user's presence in each group or nested subgroups
req := service.Members.HasMember(group, email)
func userInGroup(service *admin.Service, group string, email string) bool {
// Use the HasMember API to checking for the user's presence in each group or nested subgroups
req := service.Members.HasMember(group, email)
r, err := req.Do()
if err == nil {
return r.IsMember
}
gerr, ok := err.(*googleapi.Error)
switch {
case ok && gerr.Code == 404:
logger.Errorf("error checking membership in group %s: group does not exist", group)
case ok && gerr.Code == 400:
// It is possible for Members.HasMember to return false even if the email is a group member.
// One case that can cause this is if the user email is from a different domain than the group,
// e.g. "member@otherdomain.com" in the group "group@mydomain.com" will result in a 400 error
// from the HasMember API. In that case, attempt to query the member object directly from the group.
req := service.Members.Get(group, email)
r, err := req.Do()
if err != nil {
gerr, ok := err.(*googleapi.Error)
switch {
case ok && gerr.Code == 404:
logger.Errorf("error checking membership in group %s: group does not exist", group)
case ok && gerr.Code == 400:
// It is possible for Members.HasMember to return false even if the email is a group member.
// One case that can cause this is if the user email is from a different domain than the group,
// e.g. "member@otherdomain.com" in the group "group@mydomain.com" will result in a 400 error
// from the HasMember API. In that case, attempt to query the member object directly from the group.
req := service.Members.Get(group, email)
r, err := req.Do()
if err != nil {
logger.Errorf("error using get API to check member %s of google group %s: user not in the group", email, group)
continue
}
// If the non-domain user is found within the group, still verify that they are "ACTIVE".
// Do not count the user as belonging to a group if they have another status ("ARCHIVED", "SUSPENDED", or "UNKNOWN").
if r.Status == "ACTIVE" {
return true
}
default:
logger.Errorf("error checking group membership: %v", err)
}
continue
logger.Errorf("error using get API to check member %s of google group %s: user not in the group", email, group)
return false
}
if r.IsMember {
// If the non-domain user is found within the group, still verify that they are "ACTIVE".
// Do not count the user as belonging to a group if they have another status ("ARCHIVED", "SUSPENDED", or "UNKNOWN").
if r.Status == "ACTIVE" {
return true
}
default:
logger.Errorf("error checking group membership: %v", err)
}
return false
}
// ValidateGroup validates that the provided email exists in the configured Google
// group(s).
func (p *GoogleProvider) ValidateGroup(email string) bool {
return p.GroupValidator(email)
}
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required
func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
@ -261,8 +278,11 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
return false, err
}
// TODO (@NickMeves) - Align Group authorization needs with other providers'
// behavior in the `RefreshSession` case.
//
// re-check that the user is in the proper google group(s)
if !p.ValidateGroup(s.Email) {
if !p.groupValidator(s) {
return false, fmt.Errorf("%s is no longer in the group(s)", s.Email)
}

View File

@ -10,6 +10,7 @@ import (
"net/url"
"testing"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/assert"
admin "google.golang.org/api/admin/directory/v1"
@ -109,21 +110,50 @@ func TestGoogleProviderGetEmailAddress(t *testing.T) {
assert.Equal(t, "refresh12345", session.RefreshToken)
}
func TestGoogleProviderValidateGroup(t *testing.T) {
p := newGoogleProvider()
p.GroupValidator = func(email string) bool {
return email == "michael.bland@gsa.gov"
}
assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov"))
p.GroupValidator = func(email string) bool {
return email != "michael.bland@gsa.gov"
}
assert.Equal(t, false, p.ValidateGroup("michael.bland@gsa.gov"))
}
func TestGoogleProviderGroupValidator(t *testing.T) {
const sessionEmail = "michael.bland@gsa.gov"
func TestGoogleProviderWithoutValidateGroup(t *testing.T) {
p := newGoogleProvider()
assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov"))
testCases := map[string]struct {
session *sessions.SessionState
validatorFunc func(*sessions.SessionState) bool
expectedAuthZ bool
}{
"Email is authorized with groupValidator": {
session: &sessions.SessionState{
Email: sessionEmail,
},
validatorFunc: func(s *sessions.SessionState) bool {
return s.Email == sessionEmail
},
expectedAuthZ: true,
},
"Email is denied with groupValidator": {
session: &sessions.SessionState{
Email: sessionEmail,
},
validatorFunc: func(s *sessions.SessionState) bool {
return s.Email != sessionEmail
},
expectedAuthZ: false,
},
"Default does no authorization checks": {
session: &sessions.SessionState{
Email: sessionEmail,
},
validatorFunc: nil,
expectedAuthZ: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
g := NewWithT(t)
p := newGoogleProvider()
if tc.validatorFunc != nil {
p.groupValidator = tc.validatorFunc
}
g.Expect(p.groupValidator(tc.session)).To(Equal(tc.expectedAuthZ))
})
}
}
//
@ -196,7 +226,7 @@ func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) {
}
func TestGoogleProviderUserInGroup(t *testing.T) {
func TestGoogleProvider_userInGroup(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/groups/group@example.com/hasMember/member-in-domain@example.com" {
fmt.Fprintln(w, `{"isMember": true}`)
@ -233,18 +263,19 @@ func TestGoogleProviderUserInGroup(t *testing.T) {
ctx := context.Background()
service, err := admin.NewService(ctx, option.WithHTTPClient(client))
assert.NoError(t, err)
service.BasePath = ts.URL
assert.Equal(t, nil, err)
result := userInGroup(service, []string{"group@example.com"}, "member-in-domain@example.com")
result := userInGroup(service, "group@example.com", "member-in-domain@example.com")
assert.True(t, result)
result = userInGroup(service, []string{"group@example.com"}, "member-out-of-domain@otherexample.com")
result = userInGroup(service, "group@example.com", "member-out-of-domain@otherexample.com")
assert.True(t, result)
result = userInGroup(service, []string{"group@example.com"}, "non-member-in-domain@example.com")
result = userInGroup(service, "group@example.com", "non-member-in-domain@example.com")
assert.False(t, result)
result = userInGroup(service, []string{"group@example.com"}, "non-member-out-of-domain@otherexample.com")
result = userInGroup(service, "group@example.com", "non-member-out-of-domain@otherexample.com")
assert.False(t, result)
}

View File

@ -4,7 +4,6 @@ import (
"bytes"
"context"
"crypto/rsa"
"errors"
"fmt"
"math/rand"
"net/url"
@ -153,10 +152,9 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint
}
// Redeem exchanges the OAuth2 authentication token for an ID token
func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
if code == "" {
err = errors.New("missing code")
return
return nil, ErrMissingCode
}
claims := &jwt.StandardClaims{
@ -169,7 +167,7 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), claims)
ss, err := token.SignedString(p.JWTKey)
if err != nil {
return
return nil, err
}
params := url.Values{}
@ -199,28 +197,27 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
// check nonce here
err = checkNonce(jsonResponse.IDToken, p)
if err != nil {
return
return nil, err
}
// Get the email address
var email string
email, err = emailFromUserInfo(ctx, jsonResponse.AccessToken, p.ProfileURL.String())
if err != nil {
return
return nil, err
}
created := time.Now()
expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
// Store the data that we found in the session state
s = &sessions.SessionState{
return &sessions.SessionState{
AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken,
CreatedAt: &created,
ExpiresOn: &expires,
Email: email,
}
return
}, nil
}
// GetLoginURL overrides GetLoginURL to add login.gov parameters

View File

@ -26,6 +26,10 @@ type ProviderData struct {
ClientSecretFile string
Scope string
Prompt string
// Universal Group authorization data structure
// any provider can set to consume
AllowedGroups map[string]struct{}
}
// Data returns the ProviderData
@ -45,6 +49,15 @@ func (p *ProviderData) GetClientSecret() (clientSecret string, err error) {
return string(fileClientSecret), nil
}
// SetAllowedGroups organizes a group list into the AllowedGroups map
// to be consumed by Authorize implementations
func (p *ProviderData) SetAllowedGroups(groups []string) {
p.AllowedGroups = make(map[string]struct{}, len(groups))
for _, group := range groups {
p.AllowedGroups[group] = struct{}{}
}
}
type providerDefaults struct {
name string
loginURL *url.URL

View File

@ -19,18 +19,21 @@ var (
// implementation method that doesn't have sensible defaults
ErrNotImplemented = errors.New("not implemented")
// ErrMissingCode is returned when a Redeem method is called with an empty
// code
ErrMissingCode = errors.New("missing code")
_ Provider = (*ProviderData)(nil)
)
// Redeem provides a default implementation of the OAuth2 token redemption process
func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
if code == "" {
err = errors.New("missing code")
return
return nil, ErrMissingCode
}
clientSecret, err := p.GetClientSecret()
if err != nil {
return
return nil, err
}
params := url.Values{}
@ -59,24 +62,21 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
}
err = result.UnmarshalInto(&jsonResponse)
if err == nil {
s = &sessions.SessionState{
return &sessions.SessionState{
AccessToken: jsonResponse.AccessToken,
}
return
}, nil
}
var v url.Values
v, err = url.ParseQuery(string(result.Body()))
values, err := url.ParseQuery(string(result.Body()))
if err != nil {
return
return nil, err
}
if a := v.Get("access_token"); a != "" {
if token := values.Get("access_token"); token != "" {
created := time.Now()
s = &sessions.SessionState{AccessToken: a, CreatedAt: &created}
} else {
err = fmt.Errorf("no access token found %s", result.Body())
return &sessions.SessionState{AccessToken: token, CreatedAt: &created}, nil
}
return
return nil, fmt.Errorf("no access token found %s", result.Body())
}
// GetLoginURL with typical oauth parameters
@ -92,18 +92,28 @@ func (p *ProviderData) GetEmailAddress(_ context.Context, _ *sessions.SessionSta
return "", ErrNotImplemented
}
// ValidateGroup validates that the provided email exists in the configured provider
// email group(s).
func (p *ProviderData) ValidateGroup(_ string) bool {
return true
}
// EnrichSessionState is called after Redeem to allow providers to enrich session fields
// such as User, Email, Groups with provider specific API calls.
func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.SessionState) error {
return nil
}
// Authorize performs global authorization on an authenticated session.
// This is not used for fine-grained per route authorization rules.
func (p *ProviderData) Authorize(_ context.Context, s *sessions.SessionState) (bool, error) {
if len(p.AllowedGroups) == 0 {
return true, nil
}
for _, group := range s.Groups {
if _, ok := p.AllowedGroups[group]; ok {
return true, nil
}
}
return false, nil
}
// ValidateSessionState validates the AccessToken
func (p *ProviderData) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool {
return validateToken(ctx, p, s.AccessToken, nil)

View File

@ -7,6 +7,7 @@ import (
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/assert"
)
@ -53,3 +54,53 @@ func TestEnrichSessionState(t *testing.T) {
s := &sessions.SessionState{}
assert.NoError(t, p.EnrichSessionState(context.Background(), s))
}
func TestProviderDataAuthorize(t *testing.T) {
testCases := []struct {
name string
allowedGroups []string
groups []string
expectedAuthZ bool
}{
{
name: "NoAllowedGroups",
allowedGroups: []string{},
groups: []string{},
expectedAuthZ: true,
},
{
name: "NoAllowedGroupsUserHasGroups",
allowedGroups: []string{},
groups: []string{"foo", "bar"},
expectedAuthZ: true,
},
{
name: "UserInAllowedGroup",
allowedGroups: []string{"foo"},
groups: []string{"foo", "bar"},
expectedAuthZ: true,
},
{
name: "UserNotInAllowedGroup",
allowedGroups: []string{"bar"},
groups: []string{"baz", "foo"},
expectedAuthZ: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
g := NewWithT(t)
session := &sessions.SessionState{
Groups: tc.groups,
}
p := &ProviderData{}
p.SetAllowedGroups(tc.allowedGroups)
authorized, err := p.Authorize(context.Background(), session)
g.Expect(err).ToNot(HaveOccurred())
g.Expect(authorized).To(Equal(tc.expectedAuthZ))
})
}
}

View File

@ -13,8 +13,8 @@ type Provider interface {
// DEPRECATED: Migrate to EnrichSessionState
GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error)
Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error)
ValidateGroup(string) bool
EnrichSessionState(ctx context.Context, s *sessions.SessionState) error
Authorize(ctx context.Context, s *sessions.SessionState) (bool, error)
ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool
GetLoginURL(redirectURI, finalRedirect string) string
RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error)