diff --git a/CHANGELOG.md b/CHANGELOG.md index 783c6dc7..43a58a78 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,11 +4,14 @@ ## Important Notes +- [#616](https://github.com/oauth2-proxy/oauth2-proxy/pull/616) Ensure you have configured oauth2-proxy to use the `groups` scope. The user may be logged out initially as they may not currently have the `groups` claim however after going back through login process wil be authenticated. + ## Breaking Changes ## Changes since v6.1.1 - [#764](https://github.com/oauth2-proxy/oauth2-proxy/pull/764) Document bcrypt encryption for htpasswd (and hide SHA) (@lentzi90) +- [#616](https://github.com/oauth2-proxy/oauth2-proxy/pull/616) Add support to ensure user belongs in required groups when using the OIDC provider (@stefansedich) # v6.1.1 diff --git a/docs/configuration/configuration.md b/docs/configuration/configuration.md index 10e0afcc..0370cdf4 100644 --- a/docs/configuration/configuration.md +++ b/docs/configuration/configuration.md @@ -78,12 +78,13 @@ An example [oauth2-proxy.cfg]({{ site.gitweb }}/contrib/oauth2-proxy.cfg.example | `--insecure-oidc-skip-issuer-verification` | bool | allow the OIDC issuer URL to differ from the expected (currently required for Azure multi-tenant compatibility) | false | | `--oidc-issuer-url` | string | the OpenID Connect issuer URL, e.g. `"https://accounts.google.com"` | | | `--oidc-jwks-url` | string | OIDC JWKS URI for token verification; required if OIDC discovery is disabled | | +| `--oidc-groups-claim` | string | which claim contains the user groups | `"groups"` | | `--pass-access-token` | bool | pass OAuth access_token to upstream via X-Forwarded-Access-Token header | false | | `--pass-authorization-header` | bool | pass OIDC IDToken to upstream via Authorization Bearer header | false | | `--pass-basic-auth` | bool | pass HTTP Basic Auth, X-Forwarded-User, X-Forwarded-Email and X-Forwarded-Preferred-Username information to upstream | true | | `--prefer-email-to-user` | bool | Prefer to use the Email address as the Username when passing information to upstream. Will only use Username if Email is unavailable, e.g. htaccess authentication. Used in conjunction with `--pass-basic-auth` and `--pass-user-headers` | false | | `--pass-host-header` | bool | pass the request Host Header to upstream | true | -| `--pass-user-headers` | bool | pass X-Forwarded-User, X-Forwarded-Email and X-Forwarded-Preferred-Username information to upstream | true | +| `--pass-user-headers` | bool | pass X-Forwarded-User, X-Forwarded-Groups, X-Forwarded-Email and X-Forwarded-Preferred-Username information to upstream | true | | `--profile-url` | string | Profile access endpoint | | | `--prompt` | string | [OIDC prompt](https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest); if present, `approval-prompt` is ignored | `""` | | `--provider` | string | OAuth provider | google | @@ -112,7 +113,7 @@ An example [oauth2-proxy.cfg]({{ site.gitweb }}/contrib/oauth2-proxy.cfg.example | `--scope` | string | OAuth scope specification | | | `--session-cookie-minimal` | bool | strip OAuth tokens from cookie session stores if they aren't needed (cookie session store only) | false | | `--session-store-type` | string | [Session data storage backend](configuration/sessions); redis or cookie | cookie | -| `--set-xauthrequest` | bool | set X-Auth-Request-User, X-Auth-Request-Email and X-Auth-Request-Preferred-Username response headers (useful in Nginx auth_request mode) | false | +| `--set-xauthrequest` | bool | set X-Auth-Request-User, X-Auth-Request-Groups, X-Auth-Request-Email and X-Auth-Request-Preferred-Username response headers (useful in Nginx auth_request mode) | false | | `--set-authorization-header` | bool | set Authorization Bearer response header (useful in Nginx auth_request mode) | false | | `--set-basic-auth` | bool | set HTTP Basic Auth information in response (useful in Nginx auth_request mode) | false | | `--signature-key` | string | GAP-Signature request signature key (algorithm:secretkey) | | @@ -131,6 +132,7 @@ An example [oauth2-proxy.cfg]({{ site.gitweb }}/contrib/oauth2-proxy.cfg.example | `--tls-key-file` | string | path to private key file | | | `--upstream` | string \| list | the http url(s) of the upstream endpoint, file:// paths for static files or `static://` for static response. Routing is based on the path | | | `--user-id-claim` | string | which claim contains the user ID | \["email"\] | +| `--allowed-group` | string \| list | restrict logins to members of this group (may be given multiple times) | | | `--validate-url` | string | Access token validation endpoint | | | `--version` | n/a | print version string | | | `--whitelist-domain` | string \| list | allowed domains for redirection after authentication. Prefix domain with a `.` to allow subdomains (e.g. `.example.com`) \[[2](#footnote2)\] | | diff --git a/oauthproxy.go b/oauthproxy.go index f4d3c496..17d106d5 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -102,6 +102,7 @@ type OAuthProxy struct { trustedIPs *ip.NetSet Banner string Footer string + AllowedGroups []string sessionChain alice.Chain } @@ -215,6 +216,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr Banner: opts.Banner, Footer: opts.Footer, SignInMessage: buildSignInMessage(opts), + AllowedGroups: opts.AllowedGroups, basicAuthValidator: basicAuthValidator, displayHtpasswdForm: basicAuthValidator != nil, @@ -888,7 +890,10 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R return nil, ErrNeedsLogin } - if session != nil && session.Email != "" && !p.Validator(session.Email) { + invalidEmail := session != nil && session.Email != "" && !p.Validator(session.Email) + invalidGroups := session != nil && !p.validateGroups(session.Groups) + + if invalidEmail || invalidGroups { logger.Printf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session) // Invalid session, clear it err := p.ClearSessionCookie(rw, req) @@ -942,6 +947,14 @@ func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Req } else { req.Header.Del("X-Forwarded-Preferred-Username") } + + if len(session.Groups) > 0 { + for _, group := range session.Groups { + req.Header.Add("X-Forwarded-Groups", group) + } + } else { + req.Header.Del("X-Forwarded-Groups") + } } if p.SetXAuthRequest { @@ -964,6 +977,14 @@ func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Req rw.Header().Del("X-Auth-Request-Access-Token") } } + + if len(session.Groups) > 0 { + for _, group := range session.Groups { + rw.Header().Add("X-Auth-Request-Groups", group) + } + } else { + rw.Header().Del("X-Auth-Request-Groups") + } } if p.PassAccessToken { @@ -1012,6 +1033,7 @@ func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Req func (p *OAuthProxy) stripAuthHeaders(req *http.Request) { if p.PassBasicAuth { req.Header.Del("X-Forwarded-User") + req.Header.Del("X-Forwarded-Groups") req.Header.Del("X-Forwarded-Email") req.Header.Del("X-Forwarded-Preferred-Username") req.Header.Del("Authorization") @@ -1019,6 +1041,7 @@ func (p *OAuthProxy) stripAuthHeaders(req *http.Request) { if p.PassUserHeaders { req.Header.Del("X-Forwarded-User") + req.Header.Del("X-Forwarded-Groups") req.Header.Del("X-Forwarded-Email") req.Header.Del("X-Forwarded-Preferred-Username") } @@ -1049,3 +1072,23 @@ func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) { rw.Header().Set("Content-Type", applicationJSON) rw.WriteHeader(code) } + +func (p *OAuthProxy) validateGroups(groups []string) bool { + if len(p.AllowedGroups) == 0 { + return true + } + + allowedGroups := map[string]struct{}{} + + for _, group := range p.AllowedGroups { + allowedGroups[group] = struct{}{} + } + + for _, group := range groups { + if _, ok := allowedGroups[group]; ok { + return true + } + } + + return false +} diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 83425274..395df820 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -592,6 +592,37 @@ func TestPassUserHeadersWithEmail(t *testing.T) { } } +func TestPassGroupsHeadersWithGroups(t *testing.T) { + opts := baseTestOptions() + err := validation.Validate(opts) + assert.NoError(t, err) + + const emailAddress = "john.doe@example.com" + const userName = "9fcab5c9b889a557" + + groups := []string{"a", "b"} + created := time.Now() + session := &sessions.SessionState{ + User: userName, + Groups: groups, + Email: emailAddress, + AccessToken: "oauth_token", + CreatedAt: &created, + } + { + rw := httptest.NewRecorder() + req, _ := http.NewRequest("GET", opts.ProxyPrefix+"/testCase0", nil) + proxy, err := NewOAuthProxy(opts, func(email string) bool { + return email == emailAddress + }) + if err != nil { + t.Fatal(err) + } + proxy.addHeadersForProxying(rw, req, session) + assert.Equal(t, groups, req.Header["X-Forwarded-Groups"]) + } +} + func TestStripAuthHeaders(t *testing.T) { testCases := map[string]struct { SkipAuthStripHeaders bool @@ -609,6 +640,7 @@ func TestStripAuthHeaders(t *testing.T) { PassAuthorization: false, StrippedHeaders: map[string]bool{ "X-Forwarded-User": true, + "X-Forwared-Groups": true, "X-Forwarded-Email": true, "X-Forwarded-Preferred-Username": true, "X-Forwarded-Access-Token": false, @@ -623,6 +655,7 @@ func TestStripAuthHeaders(t *testing.T) { PassAuthorization: false, StrippedHeaders: map[string]bool{ "X-Forwarded-User": true, + "X-Forwared-Groups": true, "X-Forwarded-Email": true, "X-Forwarded-Preferred-Username": true, "X-Forwarded-Access-Token": true, @@ -637,6 +670,7 @@ func TestStripAuthHeaders(t *testing.T) { PassAuthorization: false, StrippedHeaders: map[string]bool{ "X-Forwarded-User": true, + "X-Forwared-Groups": true, "X-Forwarded-Email": true, "X-Forwarded-Preferred-Username": true, "X-Forwarded-Access-Token": true, @@ -651,6 +685,7 @@ func TestStripAuthHeaders(t *testing.T) { PassAuthorization: true, StrippedHeaders: map[string]bool{ "X-Forwarded-User": false, + "X-Forwared-Groups": false, "X-Forwarded-Email": false, "X-Forwarded-Preferred-Username": false, "X-Forwarded-Access-Token": false, @@ -665,6 +700,7 @@ func TestStripAuthHeaders(t *testing.T) { PassAuthorization: false, StrippedHeaders: map[string]bool{ "X-Forwarded-User": false, + "X-Forwared-Groups": false, "X-Forwarded-Email": false, "X-Forwarded-Preferred-Username": false, "X-Forwarded-Access-Token": false, @@ -679,6 +715,7 @@ func TestStripAuthHeaders(t *testing.T) { PassAuthorization: false, StrippedHeaders: map[string]bool{ "X-Forwarded-User": false, + "X-Forwared-Groups": false, "X-Forwarded-Email": false, "X-Forwarded-Preferred-Username": false, "X-Forwarded-Access-Token": false, @@ -690,6 +727,7 @@ func TestStripAuthHeaders(t *testing.T) { initialHeaders := map[string]string{ "X-Forwarded-User": "9fcab5c9b889a557", "X-Forwarded-Email": "john.doe@example.com", + "X-Forwarded-Groups": "a,b,c", "X-Forwarded-Preferred-Username": "john.doe", "X-Forwarded-Access-Token": "AccessToken", "Authorization": "bearer IDToken", @@ -1333,6 +1371,7 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { pcTest.opts = baseTestOptions() pcTest.opts.SetXAuthRequest = true + pcTest.opts.AllowedGroups = []string{"oauth_groups"} err := validation.Validate(pcTest.opts) assert.NoError(t, err) @@ -1354,13 +1393,14 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { created := time.Now() startSession := &sessions.SessionState{ - User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created} + User: "oauth_user", Groups: []string{"oauth_groups"}, Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created} err = pcTest.SaveSession(startSession) assert.NoError(t, err) pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req) assert.Equal(t, http.StatusAccepted, pcTest.rw.Code) assert.Equal(t, "oauth_user", pcTest.rw.Header().Get("X-Auth-Request-User")) + assert.Equal(t, startSession.Groups, pcTest.rw.Header().Values("X-Auth-Request-Groups")) assert.Equal(t, "oauth_user@example.com", pcTest.rw.Header().Get("X-Auth-Request-Email")) } @@ -2199,3 +2239,108 @@ func TestTrustedIPs(t *testing.T) { }) } } + +func TestProxyAllowedGroups(t *testing.T) { + tests := []struct { + name string + allowedGroups []string + groups []string + expectUnauthorized bool + }{ + {"NoAllowedGroups", []string{}, []string{}, false}, + {"NoAllowedGroupsUserHasGroups", []string{}, []string{"a", "b"}, false}, + {"UserInAllowedGroup", []string{"a"}, []string{"a", "b"}, false}, + {"UserNotInAllowedGroup", []string{"a"}, []string{"c"}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + emailAddress := "test" + created := time.Now() + + session := &sessions.SessionState{ + Groups: tt.groups, + Email: emailAddress, + AccessToken: "oauth_token", + CreatedAt: &created, + } + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + t.Cleanup(upstream.Close) + + test, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) { + opts.AllowedGroups = tt.allowedGroups + opts.UpstreamServers = options.Upstreams{ + { + ID: upstream.URL, + Path: "/", + URI: upstream.URL, + }, + } + }) + if err != nil { + t.Fatal(err) + } + + test.req, _ = http.NewRequest("GET", "/", nil) + + test.req.Header.Add("accept", applicationJSON) + test.SaveSession(session) + test.proxy.ServeHTTP(test.rw, test.req) + + if tt.expectUnauthorized { + assert.Equal(t, http.StatusUnauthorized, test.rw.Code) + } else { + assert.Equal(t, http.StatusOK, test.rw.Code) + } + }) + } +} + +func TestAuthOnlyAllowedGroups(t *testing.T) { + tests := []struct { + name string + allowedGroups []string + groups []string + expectUnauthorized bool + }{ + {"NoAllowedGroups", []string{}, []string{}, false}, + {"NoAllowedGroupsUserHasGroups", []string{}, []string{"a", "b"}, false}, + {"UserInAllowedGroup", []string{"a"}, []string{"a", "b"}, false}, + {"UserNotInAllowedGroup", []string{"a"}, []string{"c"}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + emailAddress := "test" + created := time.Now() + + session := &sessions.SessionState{ + Groups: tt.groups, + Email: emailAddress, + AccessToken: "oauth_token", + CreatedAt: &created, + } + + test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) { + opts.AllowedGroups = tt.allowedGroups + }) + if err != nil { + t.Fatal(err) + } + + err = test.SaveSession(session) + assert.NoError(t, err) + + test.proxy.ServeHTTP(test.rw, test.req) + + if tt.expectUnauthorized { + assert.Equal(t, http.StatusUnauthorized, test.rw.Code) + } else { + assert.Equal(t, http.StatusAccepted, test.rw.Code) + } + }) + } +} diff --git a/pkg/apis/options/options.go b/pkg/apis/options/options.go index b723b60b..e9f506cb 100644 --- a/pkg/apis/options/options.go +++ b/pkg/apis/options/options.go @@ -93,6 +93,7 @@ type Options struct { InsecureOIDCSkipIssuerVerification bool `flag:"insecure-oidc-skip-issuer-verification" cfg:"insecure_oidc_skip_issuer_verification"` SkipOIDCDiscovery bool `flag:"skip-oidc-discovery" cfg:"skip_oidc_discovery"` OIDCJwksURL string `flag:"oidc-jwks-url" cfg:"oidc_jwks_url"` + OIDCGroupsClaim string `flag:"oidc-groups-claim" cfg:"oidc_groups_claim"` LoginURL string `flag:"login-url" cfg:"login_url"` RedeemURL string `flag:"redeem-url" cfg:"redeem_url"` ProfileURL string `flag:"profile-url" cfg:"profile_url"` @@ -102,6 +103,7 @@ type Options struct { Prompt string `flag:"prompt" cfg:"prompt"` ApprovalPrompt string `flag:"approval-prompt" cfg:"approval_prompt"` // Deprecated by OIDC 1.0 UserIDClaim string `flag:"user-id-claim" cfg:"user_id_claim"` + AllowedGroups []string `flag:"allowed-group" cfg:"allowed_groups"` SignatureKey string `flag:"signature-key" cfg:"signature_key"` AcrValues string `flag:"acr-values" cfg:"acr_values"` @@ -167,6 +169,7 @@ func NewOptions() *Options { InsecureOIDCAllowUnverifiedEmail: false, SkipOIDCDiscovery: false, Logging: loggingDefaults(), + OIDCGroupsClaim: "groups", } } @@ -248,6 +251,7 @@ func NewFlagSet() *pflag.FlagSet { flagSet.Bool("insecure-oidc-skip-issuer-verification", false, "Do not verify if issuer matches OIDC discovery URL") flagSet.Bool("skip-oidc-discovery", false, "Skip OIDC discovery and use manually supplied Endpoints") flagSet.String("oidc-jwks-url", "", "OpenID Connect JWKS URL (ie: https://www.googleapis.com/oauth2/v3/certs)") + flagSet.String("oidc-groups-claim", "groups", "which claim contains the user groups") flagSet.String("login-url", "", "Authentication endpoint") flagSet.String("redeem-url", "", "Token redemption endpoint") flagSet.String("profile-url", "", "Profile access endpoint") @@ -265,6 +269,7 @@ func NewFlagSet() *pflag.FlagSet { flagSet.Bool("gcp-healthchecks", false, "Enable GCP/GKE healthcheck endpoints") flagSet.String("user-id-claim", "email", "which claim contains the user ID") + flagSet.StringSlice("allowed-group", []string{}, "restrict logins to members of this group (may be given multiple times)") flagSet.AddFlagSet(cookieFlagSet()) flagSet.AddFlagSet(loggingFlagSet()) diff --git a/pkg/apis/sessions/session_state.go b/pkg/apis/sessions/session_state.go index e69c4db4..b10c347a 100644 --- a/pkg/apis/sessions/session_state.go +++ b/pkg/apis/sessions/session_state.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "io/ioutil" + "reflect" "time" "unicode/utf8" @@ -24,6 +25,7 @@ type SessionState struct { RefreshToken string `json:",omitempty" msgpack:"rt,omitempty"` Email string `json:",omitempty" msgpack:"e,omitempty"` User string `json:",omitempty" msgpack:"u,omitempty"` + Groups []string `json:",omitempty" msgpack:"g,omitempty"` PreferredUsername string `json:",omitempty" msgpack:"pu,omitempty"` } @@ -61,6 +63,9 @@ func (s *SessionState) String() string { if s.RefreshToken != "" { o += " refresh_token:true" } + if len(s.Groups) > 0 { + o += fmt.Sprintf(" groups:%v", s.Groups) + } return o + "}" } @@ -233,7 +238,7 @@ func (s *SessionState) validate() error { } empty := new(SessionState) - if *s == *empty { + if reflect.DeepEqual(*s, *empty) { return errors.New("invalid empty session unmarshalled") } diff --git a/pkg/apis/sessions/session_state_test.go b/pkg/apis/sessions/session_state_test.go index 08216b26..31005928 100644 --- a/pkg/apis/sessions/session_state_test.go +++ b/pkg/apis/sessions/session_state_test.go @@ -186,6 +186,17 @@ func TestEncodeAndDecodeSessionState(t *testing.T) { IDToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", ExpiresOn: &expires, }, + "With groups": { + Email: "username@example.com", + User: "username", + PreferredUsername: "preferred.username", + AccessToken: "AccessToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + IDToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + CreatedAt: &created, + ExpiresOn: &expires, + RefreshToken: "RefreshToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + Groups: []string{"group-a", "group-b"}, + }, } for _, secretSize := range []int{16, 24, 32} { diff --git a/pkg/validation/options.go b/pkg/validation/options.go index f9325cf0..75f678be 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -152,6 +152,10 @@ func Validate(o *options.Options) error { } if o.Scope == "" { o.Scope = "openid email profile" + + if len(o.AllowedGroups) > 0 { + o.Scope += " groups" + } } } @@ -279,6 +283,7 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { case *providers.OIDCProvider: p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail p.UserIDClaim = o.UserIDClaim + p.GroupsClaim = o.OIDCGroupsClaim if o.GetOIDCVerifier() == nil { msgs = append(msgs, "oidc provider requires an oidc issuer URL") } else { diff --git a/providers/oidc.go b/providers/oidc.go index b14e0b61..7162740f 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -22,6 +22,7 @@ type OIDCProvider struct { Verifier *oidc.IDTokenVerifier AllowUnverifiedEmail bool UserIDClaim string + GroupsClaim string } // NewOIDCProvider initiates a new OIDCProvider @@ -123,6 +124,7 @@ func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sessi s.IDToken = newSession.IDToken s.Email = newSession.Email s.User = newSession.User + s.Groups = newSession.Groups s.PreferredUsername = newSession.PreferredUsername } @@ -204,6 +206,7 @@ func (p *OIDCProvider) createSessionStateInternal(ctx context.Context, idToken * newSession.Email = claims.UserID // TODO Rename SessionState.Email to .UserID in the near future newSession.User = claims.Subject + newSession.Groups = claims.Groups newSession.PreferredUsername = claims.PreferredUsername verifyEmail := (p.UserIDClaim == emailClaim) && !p.AllowUnverifiedEmail @@ -222,6 +225,7 @@ func (p *OIDCProvider) ValidateSessionState(ctx context.Context, s *sessions.Ses func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*OIDCClaims, error) { claims := &OIDCClaims{} + // Extract default claims. if err := idToken.Claims(&claims); err != nil { return nil, fmt.Errorf("failed to parse default id_token claims: %v", err) @@ -236,6 +240,8 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc. claims.UserID = fmt.Sprint(userID) } + claims.Groups = p.extractGroupsFromRawClaims(claims.rawClaims) + // userID claim was not present or was empty in the ID Token if claims.UserID == "" { // BearerToken case, allow empty UserID @@ -273,10 +279,27 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc. return claims, nil } +func (p *OIDCProvider) extractGroupsFromRawClaims(rawClaims map[string]interface{}) []string { + groups := []string{} + + rawGroups, ok := rawClaims[p.GroupsClaim].([]interface{}) + if rawGroups != nil && ok { + for _, rawGroup := range rawGroups { + group, ok := rawGroup.(string) + if ok { + groups = append(groups, group) + } + } + } + + return groups +} + type OIDCClaims struct { rawClaims map[string]interface{} UserID string Subject string `json:"sub"` Verified *bool `json:"email_verified"` PreferredUsername string `json:"preferred_username"` + Groups []string } diff --git a/providers/oidc_test.go b/providers/oidc_test.go index 9e96752d..5e91418b 100644 --- a/providers/oidc_test.go +++ b/providers/oidc_test.go @@ -29,10 +29,12 @@ const clientID = "https://test.myapp.com" const secret = "secret" type idTokenClaims struct { - Name string `json:"name,omitempty"` - Email string `json:"email,omitempty"` - Phone string `json:"phone_number,omitempty"` - Picture string `json:"picture,omitempty"` + Name string `json:"name,omitempty"` + Email string `json:"email,omitempty"` + Phone string `json:"phone_number,omitempty"` + Picture string `json:"picture,omitempty"` + Groups []string `json:"groups,omitempty"` + OtherGroups []string `json:"other_groups,omitempty"` jwt.StandardClaims } @@ -49,6 +51,8 @@ var defaultIDToken idTokenClaims = idTokenClaims{ "janed@me.com", "+4798765432", "http://mugbook.com/janed/me.jpg", + []string{"test:a", "test:b"}, + []string{"test:c", "test:d"}, jwt.StandardClaims{ Audience: "https://test.myapp.com", ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(), @@ -65,6 +69,8 @@ var minimalIDToken idTokenClaims = idTokenClaims{ "", "", "", + []string{}, + []string{}, jwt.StandardClaims{ Audience: "https://test.myapp.com", ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(), @@ -273,25 +279,39 @@ func TestCreateSessionStateFromBearerToken(t *testing.T) { const profileURLEmail = "janed@me.com" testCases := map[string]struct { - IDToken idTokenClaims - ExpectedUser string - ExpectedEmail string + IDToken idTokenClaims + GroupsClaim string + ExpectedUser string + ExpectedEmail string + ExpectedGroups []string }{ "Default IDToken": { - IDToken: defaultIDToken, - ExpectedUser: defaultIDToken.Subject, - ExpectedEmail: defaultIDToken.Email, + IDToken: defaultIDToken, + GroupsClaim: "groups", + ExpectedUser: defaultIDToken.Subject, + ExpectedEmail: defaultIDToken.Email, + ExpectedGroups: []string{"test:a", "test:b"}, }, "Minimal IDToken with no email claim": { - IDToken: minimalIDToken, - ExpectedUser: minimalIDToken.Subject, - ExpectedEmail: minimalIDToken.Subject, + IDToken: minimalIDToken, + GroupsClaim: "groups", + ExpectedUser: minimalIDToken.Subject, + ExpectedEmail: minimalIDToken.Subject, + ExpectedGroups: []string{}, + }, + "Custom Groups Claim": { + IDToken: defaultIDToken, + GroupsClaim: "other_groups", + ExpectedUser: defaultIDToken.Subject, + ExpectedEmail: defaultIDToken.Email, + ExpectedGroups: []string{"test:c", "test:d"}, }, } for testName, tc := range testCases { t.Run(testName, func(t *testing.T) { jsonResp := []byte(fmt.Sprintf(`{"email":"%s"}`, profileURLEmail)) server, provider := newTestSetup(jsonResp) + provider.GroupsClaim = tc.GroupsClaim defer server.Close() rawIDToken, err := newSignedTestIDToken(tc.IDToken) @@ -311,6 +331,7 @@ func TestCreateSessionStateFromBearerToken(t *testing.T) { assert.Equal(t, tc.ExpectedEmail, ss.Email) assert.Equal(t, rawIDToken, ss.IDToken) assert.Equal(t, rawIDToken, ss.AccessToken) + assert.Equal(t, tc.ExpectedGroups, ss.Groups) assert.Equal(t, "", ss.RefreshToken) }) }