1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-05-13 22:06:40 +02:00

Support non-list and complex groups

This commit is contained in:
Nick Meves 2020-11-29 14:58:01 -08:00
parent eb56f24d6d
commit ea5b8cc21f
No known key found for this signature in database
GPG Key ID: 93BA8A3CEDCDD1CF
6 changed files with 166 additions and 36 deletions

View File

@ -48,6 +48,7 @@
- [#630](https://github.com/oauth2-proxy/oauth2-proxy/pull/630) Add support for Gitlab project based authentication (@factorysh) - [#630](https://github.com/oauth2-proxy/oauth2-proxy/pull/630) Add support for Gitlab project based authentication (@factorysh)
- [#907](https://github.com/oauth2-proxy/oauth2-proxy/pull/907) Introduce alpha configuration option to enable testing of structured configuration (@JoelSpeed) - [#907](https://github.com/oauth2-proxy/oauth2-proxy/pull/907) Introduce alpha configuration option to enable testing of structured configuration (@JoelSpeed)
- [#938](https://github.com/oauth2-proxy/oauth2-proxy/pull/938) Cleanup missed provider renaming refactor methods (@NickMeves) - [#938](https://github.com/oauth2-proxy/oauth2-proxy/pull/938) Cleanup missed provider renaming refactor methods (@NickMeves)
- [#816](https://github.com/oauth2-proxy/oauth2-proxy/pull/816) (via [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936)) Support non-list group claims (@loafoe)
- [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936) Refactor OIDC Provider and support groups from Profile URL (@NickMeves) - [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936) Refactor OIDC Provider and support groups from Profile URL (@NickMeves)
- [#925](https://github.com/oauth2-proxy/oauth2-proxy/pull/925) Fix basic auth legacy header conversion (@JoelSpeed) - [#925](https://github.com/oauth2-proxy/oauth2-proxy/pull/925) Fix basic auth legacy header conversion (@JoelSpeed)
- [#916](https://github.com/oauth2-proxy/oauth2-proxy/pull/916) Add AlphaOptions struct to prepare for alpha config loading (@JoelSpeed) - [#916](https://github.com/oauth2-proxy/oauth2-proxy/pull/916) Add AlphaOptions struct to prepare for alpha config loading (@JoelSpeed)

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"reflect"
"time" "time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
@ -59,7 +60,7 @@ func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionSta
} }
// Try to get missing emails or groups from a profileURL // Try to get missing emails or groups from a profileURL
if s.Email == "" || len(s.Groups) == 0 { if s.Email == "" || s.Groups == nil {
err := p.callProfileURL(ctx, s) err := p.callProfileURL(ctx, s)
if err != nil { if err != nil {
logger.Errorf("Warning: Profile URL request failed: %v", err) logger.Errorf("Warning: Profile URL request failed: %v", err)
@ -90,16 +91,15 @@ func (p *OIDCProvider) callProfileURL(ctx context.Context, s *sessions.SessionSt
s.Email = email s.Email = email
} }
// Handle array & singleton groups cases
if len(s.Groups) == 0 { if len(s.Groups) == 0 {
groups, err := respJSON.Get(p.GroupsClaim).StringArray() for _, group := range coerceArray(respJSON, p.GroupsClaim) {
if err == nil { formatted, err := formatGroup(group)
s.Groups = groups if err != nil {
} else { logger.Errorf("Warning: unable to format group of type %s with error %s",
group, err := respJSON.Get(p.GroupsClaim).String() reflect.TypeOf(group), err)
if err == nil { continue
s.Groups = []string{group}
} }
s.Groups = append(s.Groups, formatted)
} }
} }

View File

@ -68,7 +68,7 @@ func newOIDCServer(body []byte) (*url.URL, *httptest.Server) {
return u, s return u, s
} }
func newTestSetup(body []byte) (*httptest.Server, *OIDCProvider) { func newTestOIDCSetup(body []byte) (*httptest.Server, *OIDCProvider) {
redeemURL, server := newOIDCServer(body) redeemURL, server := newOIDCServer(body)
provider := newOIDCProvider(redeemURL) provider := newOIDCProvider(redeemURL)
return server, provider return server, provider
@ -85,7 +85,7 @@ func TestOIDCProviderRedeem(t *testing.T) {
IDToken: idToken, IDToken: idToken,
}) })
server, provider := newTestSetup(body) server, provider := newTestOIDCSetup(body)
defer server.Close() defer server.Close()
session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234") session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234")
@ -108,7 +108,7 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) {
IDToken: idToken, IDToken: idToken,
}) })
server, provider := newTestSetup(body) server, provider := newTestOIDCSetup(body)
provider.EmailClaim = "phone_number" provider.EmailClaim = "phone_number"
defer server.Close() defer server.Close()
@ -247,7 +247,7 @@ func TestOIDCProvider_EnrichSession(t *testing.T) {
ExistingSession: &sessions.SessionState{ ExistingSession: &sessions.SessionState{
User: "already", User: "already",
Email: "already@populated.com", Email: "already@populated.com",
Groups: []string{}, Groups: nil,
IDToken: idToken, IDToken: idToken,
AccessToken: accessToken, AccessToken: accessToken,
RefreshToken: refreshToken, RefreshToken: refreshToken,
@ -268,6 +268,89 @@ func TestOIDCProvider_EnrichSession(t *testing.T) {
RefreshToken: refreshToken, RefreshToken: refreshToken,
}, },
}, },
"Missing Groups with Complex Groups in Profile URL": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: nil,
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"groups": []map[string]interface{}{
{
"groupId": "Admin Group Id",
"roles": []string{"Admin"},
},
},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Groups with Singleton Complex Group in Profile URL": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: nil,
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"groups": map[string]interface{}{
"groupId": "Admin Group Id",
"roles": []string{"Admin"},
},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Empty Groups Claims": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"groups": []string{"new", "thing"},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Groups with Custom Claim": { "Missing Groups with Custom Claim": {
ExistingSession: &sessions.SessionState{ ExistingSession: &sessions.SessionState{
User: "already", User: "already",
@ -297,7 +380,7 @@ func TestOIDCProvider_EnrichSession(t *testing.T) {
ExistingSession: &sessions.SessionState{ ExistingSession: &sessions.SessionState{
User: "already", User: "already",
Email: "already@populated.com", Email: "already@populated.com",
Groups: []string{}, Groups: nil,
IDToken: idToken, IDToken: idToken,
AccessToken: accessToken, AccessToken: accessToken,
RefreshToken: refreshToken, RefreshToken: refreshToken,
@ -346,7 +429,7 @@ func TestOIDCProvider_EnrichSession(t *testing.T) {
jsonResp, err := json.Marshal(tc.ProfileJSON) jsonResp, err := json.Marshal(tc.ProfileJSON)
assert.NoError(t, err) assert.NoError(t, err)
server, provider := newTestSetup(jsonResp) server, provider := newTestOIDCSetup(jsonResp)
provider.ProfileURL, err = url.Parse(server.URL) provider.ProfileURL, err = url.Parse(server.URL)
assert.NoError(t, err) assert.NoError(t, err)
@ -371,7 +454,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
RefreshToken: refreshToken, RefreshToken: refreshToken,
}) })
server, provider := newTestSetup(body) server, provider := newTestOIDCSetup(body)
defer server.Close() defer server.Close()
existingSession := &sessions.SessionState{ existingSession := &sessions.SessionState{
@ -405,7 +488,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) {
IDToken: idToken, IDToken: idToken,
}) })
server, provider := newTestSetup(body) server, provider := newTestOIDCSetup(body)
defer server.Close() defer server.Close()
existingSession := &sessions.SessionState{ existingSession := &sessions.SessionState{
@ -433,7 +516,7 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) {
GroupsClaim string GroupsClaim string
ExpectedUser string ExpectedUser string
ExpectedEmail string ExpectedEmail string
ExpectedGroups interface{} ExpectedGroups []string
}{ }{
"Default IDToken": { "Default IDToken": {
IDToken: defaultIDToken, IDToken: defaultIDToken,
@ -447,7 +530,7 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) {
GroupsClaim: "groups", GroupsClaim: "groups",
ExpectedUser: "123456789", ExpectedUser: "123456789",
ExpectedEmail: "123456789", ExpectedEmail: "123456789",
ExpectedGroups: []string{}, ExpectedGroups: nil,
}, },
"Custom Groups Claim": { "Custom Groups Claim": {
IDToken: defaultIDToken, IDToken: defaultIDToken,
@ -466,7 +549,7 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) {
} }
for testName, tc := range testCases { for testName, tc := range testCases {
t.Run(testName, func(t *testing.T) { t.Run(testName, func(t *testing.T) {
server, provider := newTestSetup([]byte(`{}`)) server, provider := newTestOIDCSetup([]byte(`{}`))
provider.GroupsClaim = tc.GroupsClaim provider.GroupsClaim = tc.GroupsClaim
defer server.Close() defer server.Close()
@ -478,9 +561,9 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) {
assert.Equal(t, tc.ExpectedUser, ss.User) assert.Equal(t, tc.ExpectedUser, ss.User)
assert.Equal(t, tc.ExpectedEmail, ss.Email) assert.Equal(t, tc.ExpectedEmail, ss.Email)
assert.Equal(t, tc.ExpectedGroups, ss.Groups)
assert.Equal(t, rawIDToken, ss.IDToken) assert.Equal(t, rawIDToken, ss.IDToken)
assert.Equal(t, rawIDToken, ss.AccessToken) assert.Equal(t, rawIDToken, ss.AccessToken)
assert.Equal(t, tc.ExpectedGroups, ss.Groups)
assert.Equal(t, "", ss.RefreshToken) assert.Equal(t, "", ss.RefreshToken)
}) })
} }

View File

@ -189,12 +189,27 @@ func (p *ProviderData) getClaims(idToken *oidc.IDToken) (*OIDCClaims, error) {
return claims, nil return claims, nil
} }
// extractGroups extracts groups from a claim to a list in a type safe manner // extractGroups extracts groups from a claim to a list in a type safe manner.
// If the claim isn't present, `nil` is returned. If the groups claim is
// present but empty, `[]string{}` is returned.
func (p *ProviderData) extractGroups(claims map[string]interface{}) []string { func (p *ProviderData) extractGroups(claims map[string]interface{}) []string {
rawClaim, ok := claims[p.GroupsClaim]
if !ok {
return nil
}
// Handle traditional list-based groups as well as non-standard singleton
// based groups. Both variants support complex objects if needed.
var claimGroups []interface{}
switch raw := rawClaim.(type) {
case []interface{}:
claimGroups = raw
case interface{}:
claimGroups = []interface{}{raw}
}
groups := []string{} groups := []string{}
rawGroups, ok := claims[p.GroupsClaim].([]interface{}) for _, rawGroup := range claimGroups {
if rawGroups != nil && ok {
for _, rawGroup := range rawGroups {
formattedGroup, err := formatGroup(rawGroup) formattedGroup, err := formatGroup(rawGroup)
if err != nil { if err != nil {
logger.Errorf("Warning: unable to format group of type %s with error %s", logger.Errorf("Warning: unable to format group of type %s with error %s",
@ -203,6 +218,5 @@ func (p *ProviderData) extractGroups(claims map[string]interface{}) []string {
} }
groups = append(groups, formattedGroup) groups = append(groups, formattedGroup)
} }
}
return groups return groups
} }

View File

@ -300,7 +300,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
ExpectedSession: &sessions.SessionState{ ExpectedSession: &sessions.SessionState{
User: "123456789", User: "123456789",
Email: "janed@me.com", Email: "janed@me.com",
Groups: []string{}, Groups: nil,
PreferredUsername: "Jane Dobbs", PreferredUsername: "Jane Dobbs",
}, },
}, },
@ -386,12 +386,20 @@ func TestProviderData_extractGroups(t *testing.T) {
"Just::A::String", "Just::A::String",
}, },
}, },
"Missing Groups": { "Missing Groups Claim Returns Nil": {
Claims: map[string]interface{}{ Claims: map[string]interface{}{
"email": "this@does.not.matter.com", "email": "this@does.not.matter.com",
}, },
GroupsClaim: "groups", GroupsClaim: "groups",
ExpectedGroups: []string{}, ExpectedGroups: nil,
},
"Non List Groups": {
Claims: map[string]interface{}{
"email": "this@does.not.matter.com",
"groups": "singleton",
},
GroupsClaim: "groups",
ExpectedGroups: []string{"singleton"},
}, },
} }
for testName, tc := range testCases { for testName, tc := range testCases {
@ -408,7 +416,11 @@ func TestProviderData_extractGroups(t *testing.T) {
provider.GroupsClaim = tc.GroupsClaim provider.GroupsClaim = tc.GroupsClaim
groups := provider.extractGroups(tc.Claims) groups := provider.extractGroups(tc.Claims)
if tc.ExpectedGroups != nil {
g.Expect(groups).To(Equal(tc.ExpectedGroups)) g.Expect(groups).To(Equal(tc.ExpectedGroups))
} else {
g.Expect(groups).To(BeNil())
}
}) })
} }
} }

View File

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"github.com/bitly/go-simplejson"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
@ -59,6 +60,8 @@ func makeLoginURL(p *ProviderData, redirectURI, state string, extraParams url.Va
return a return a
} }
// getIDToken extracts an IDToken stored in the `Extra` fields of an
// oauth2.Token
func getIDToken(token *oauth2.Token) string { func getIDToken(token *oauth2.Token) string {
idToken, ok := token.Extra("id_token").(string) idToken, ok := token.Extra("id_token").(string)
if !ok { if !ok {
@ -67,6 +70,8 @@ func getIDToken(token *oauth2.Token) string {
return idToken return idToken
} }
// formatGroup coerces an OIDC groups claim into a string
// If it is non-string, marshal it into JSON.
func formatGroup(rawGroup interface{}) (string, error) { func formatGroup(rawGroup interface{}) (string, error) {
group, ok := rawGroup.(string) group, ok := rawGroup.(string)
if !ok { if !ok {
@ -78,3 +83,18 @@ func formatGroup(rawGroup interface{}) (string, error) {
} }
return group, nil return group, nil
} }
// coerceArray extracts a field from simplejson.Json that might be a
// singleton or a list and coerces it into a list.
func coerceArray(sj *simplejson.Json, key string) []interface{} {
array, err := sj.Get(key).Array()
if err == nil {
return array
}
single := sj.Get(key).Interface()
if single == nil {
return nil
}
return []interface{}{single}
}