1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-03-23 21:50:48 +02:00

Integrate claim extractor into providers

This commit is contained in:
Joel Speed 2021-06-26 11:49:08 +01:00 committed by Joel Speed
parent 537e596904
commit 967051314e
No known key found for this signature in database
GPG Key ID: 6E80578D6751DEFB
15 changed files with 212 additions and 733 deletions

View File

@ -103,16 +103,17 @@ func (p *ADFSProvider) RefreshSession(ctx context.Context, s *sessions.SessionSt
} }
func (p *ADFSProvider) fallbackUPN(ctx context.Context, s *sessions.SessionState) error { func (p *ADFSProvider) fallbackUPN(ctx context.Context, s *sessions.SessionState) error {
idToken, err := p.Verifier.Verify(ctx, s.IDToken) claims, err := p.getClaimExtractor(s.IDToken, s.AccessToken)
if err != nil { if err != nil {
return err return fmt.Errorf("could not extract claims: %v", err)
} }
claims, err := p.getClaims(idToken)
upn, found, err := claims.GetClaim(adfsUPNClaim)
if err != nil { if err != nil {
return fmt.Errorf("couldn't extract claims from id_token (%v)", err) return fmt.Errorf("could not extract %s claim: %v", adfsUPNClaim, err)
} }
upn := claims.raw[adfsUPNClaim]
if upn != nil { if found && fmt.Sprint(upn) != "" {
s.Email = fmt.Sprint(upn) s.Email = fmt.Sprint(upn)
} }
return nil return nil

View File

@ -79,7 +79,7 @@ func testADFSBackend() *httptest.Server {
{ {
"access_token": "my_access_token", "access_token": "my_access_token",
"id_token": "my_id_token", "id_token": "my_id_token",
"refresh_token": "my_refresh_token" "refresh_token": "my_refresh_token"
} }
` `
userInfo := ` userInfo := `
@ -150,9 +150,7 @@ var _ = Describe("ADFS Provider Tests", func() {
Context("with valid token", func() { Context("with valid token", func() {
It("should not throw an error", func() { It("should not throw an error", func() {
rawIDToken, _ := newSignedTestIDToken(defaultIDToken) rawIDToken, _ := newSignedTestIDToken(defaultIDToken)
idToken, err := p.Verifier.Verify(context.Background(), rawIDToken) session, err := p.buildSessionFromClaims(rawIDToken, "")
Expect(err).To(BeNil())
session, err := p.buildSessionFromClaims(idToken)
Expect(err).To(BeNil()) Expect(err).To(BeNil())
session.IDToken = rawIDToken session.IDToken = rawIDToken
err = p.EnrichSession(context.Background(), session) err = p.EnrichSession(context.Background(), session)

View File

@ -15,9 +15,20 @@ func CreateAuthorizedSession() *sessions.SessionState {
} }
func IsAuthorizedInHeader(reqHeader http.Header) bool { func IsAuthorizedInHeader(reqHeader http.Header) bool {
return reqHeader.Get("Authorization") == fmt.Sprintf("Bearer %s", authorizedAccessToken) return IsAuthorizedInHeaderWithToken(reqHeader, authorizedAccessToken)
}
func IsAuthorizedInHeaderWithToken(reqHeader http.Header, token string) bool {
return reqHeader.Get("Authorization") == fmt.Sprintf("Bearer %s", token)
} }
func IsAuthorizedInURL(reqURL *url.URL) bool { func IsAuthorizedInURL(reqURL *url.URL) bool {
return reqURL.Query().Get("access_token") == authorizedAccessToken return reqURL.Query().Get("access_token") == authorizedAccessToken
} }
func isAuthorizedRefreshInURLWithToken(reqURL *url.URL, token string) bool {
if token == "" {
return false
}
return reqURL.Query().Get("refresh_token") == token
}

View File

@ -78,6 +78,7 @@ func NewAzureProvider(p *ProviderData) *AzureProvider {
if p.ValidateURL == nil || p.ValidateURL.String() == "" { if p.ValidateURL == nil || p.ValidateURL.String() == "" {
p.ValidateURL = p.ProfileURL p.ValidateURL = p.ProfileURL
} }
p.getAuthorizationHeaderFunc = makeAzureHeader
return &AzureProvider{ return &AzureProvider{
ProviderData: p, ProviderData: p,
@ -150,7 +151,7 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*
session.CreatedAtNow() session.CreatedAtNow()
session.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0)) session.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0))
email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken) email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken, session.AccessToken)
// https://github.com/oauth2-proxy/oauth2-proxy/pull/914#issuecomment-782285814 // https://github.com/oauth2-proxy/oauth2-proxy/pull/914#issuecomment-782285814
// https://github.com/AzureAD/azure-activedirectory-library-for-java/issues/117 // https://github.com/AzureAD/azure-activedirectory-library-for-java/issues/117
@ -163,7 +164,7 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*
} }
if session.Email == "" { if session.Email == "" {
email, err = p.verifyTokenAndExtractEmail(ctx, session.AccessToken) email, err = p.verifyTokenAndExtractEmail(ctx, session.AccessToken, session.AccessToken)
if err == nil && email != "" { if err == nil && email != "" {
session.Email = email session.Email = email
} else { } else {
@ -215,16 +216,16 @@ func (p *AzureProvider) prepareRedeem(redirectURL, code string) (url.Values, err
// verifyTokenAndExtractEmail tries to extract email claim from either id_token or access token // verifyTokenAndExtractEmail tries to extract email claim from either id_token or access token
// when oidc verifier is configured // when oidc verifier is configured
func (p *AzureProvider) verifyTokenAndExtractEmail(ctx context.Context, token string) (string, error) { func (p *AzureProvider) verifyTokenAndExtractEmail(ctx context.Context, rawIDToken string, accessToken string) (string, error) {
email := "" email := ""
if token != "" && p.Verifier != nil { if rawIDToken != "" && p.Verifier != nil {
token, err := p.Verifier.Verify(ctx, token) _, err := p.Verifier.Verify(ctx, rawIDToken)
// due to issues mentioned above, id_token may not be signed by AAD // due to issues mentioned above, id_token may not be signed by AAD
if err == nil { if err == nil {
claims, err := p.getClaims(token) s, err := p.buildSessionFromClaims(rawIDToken, accessToken)
if err == nil { if err == nil {
email = claims.Email email = s.Email
} else { } else {
logger.Printf("unable to get claims from token: %v", err) logger.Printf("unable to get claims from token: %v", err)
} }
@ -287,7 +288,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess
s.CreatedAtNow() s.CreatedAtNow()
s.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0)) s.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0))
email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken) email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken, s.AccessToken)
// https://github.com/oauth2-proxy/oauth2-proxy/pull/914#issuecomment-782285814 // https://github.com/oauth2-proxy/oauth2-proxy/pull/914#issuecomment-782285814
// https://github.com/AzureAD/azure-activedirectory-library-for-java/issues/117 // https://github.com/AzureAD/azure-activedirectory-library-for-java/issues/117
@ -300,7 +301,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess
} }
if s.Email == "" { if s.Email == "" {
email, err = p.verifyTokenAndExtractEmail(ctx, s.AccessToken) email, err = p.verifyTokenAndExtractEmail(ctx, s.AccessToken, s.AccessToken)
if err == nil && email != "" { if err == nil && email != "" {
s.Email = email s.Email = email
} else { } else {

View File

@ -13,9 +13,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
oidc "github.com/coreos/go-oidc/v3/oidc"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/oidc" internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/oidc"
@ -145,11 +144,11 @@ func TestAzureSetTenant(t *testing.T) {
assert.Equal(t, "openid", p.Data().Scope) assert.Equal(t, "openid", p.Data().Scope)
} }
func testAzureBackend(payload string) *httptest.Server { func testAzureBackend(payload string, accessToken, refreshToken string) *httptest.Server {
return testAzureBackendWithError(payload, false) return testAzureBackendWithError(payload, accessToken, refreshToken, false)
} }
func testAzureBackendWithError(payload string, injectError bool) *httptest.Server { func testAzureBackendWithError(payload string, accessToken, refreshToken string, injectError bool) *httptest.Server {
path := "/v1.0/me" path := "/v1.0/me"
return httptest.NewServer(http.HandlerFunc( return httptest.NewServer(http.HandlerFunc(
@ -163,7 +162,8 @@ func testAzureBackendWithError(payload string, injectError bool) *httptest.Serve
w.WriteHeader(200) w.WriteHeader(200)
} }
w.Write([]byte(payload)) w.Write([]byte(payload))
} else if !IsAuthorizedInHeader(r.Header) { } else if !IsAuthorizedInHeaderWithToken(r.Header, accessToken) &&
!isAuthorizedRefreshInURLWithToken(r.URL, refreshToken) {
w.WriteHeader(403) w.WriteHeader(403)
} else { } else {
w.WriteHeader(200) w.WriteHeader(200)
@ -224,7 +224,7 @@ func TestAzureProviderEnrichSession(t *testing.T) {
host string host string
) )
if testCase.PayloadFromAzureBackend != "" { if testCase.PayloadFromAzureBackend != "" {
b = testAzureBackend(testCase.PayloadFromAzureBackend) b = testAzureBackend(testCase.PayloadFromAzureBackend, authorizedAccessToken, "")
defer b.Close() defer b.Close()
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
@ -319,7 +319,7 @@ func TestAzureProviderRedeem(t *testing.T) {
payloadBytes, err := json.Marshal(payload) payloadBytes, err := json.Marshal(payload)
assert.NoError(t, err) assert.NoError(t, err)
b := testAzureBackendWithError(string(payloadBytes), testCase.InjectRedeemURLError) b := testAzureBackendWithError(string(payloadBytes), accessTokenString, testCase.RefreshToken, testCase.InjectRedeemURLError)
defer b.Close() defer b.Close()
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
@ -353,35 +353,44 @@ func TestAzureProviderProtectedResourceConfigured(t *testing.T) {
func TestAzureProviderRefresh(t *testing.T) { func TestAzureProviderRefresh(t *testing.T) {
email := "foo@example.com" email := "foo@example.com"
subject := "foo"
idToken := idTokenClaims{ idToken := idTokenClaims{
StandardClaims: jwt.StandardClaims{Audience: "cd6d4fae-f6a6-4a34-8454-2c6b598e9532"}, Email: email,
Email: email} StandardClaims: jwt.StandardClaims{
Audience: "cd6d4fae-f6a6-4a34-8454-2c6b598e9532",
Subject: subject,
},
}
idTokenString, err := newSignedTestIDToken(idToken) idTokenString, err := newSignedTestIDToken(idToken)
assert.NoError(t, err) assert.NoError(t, err)
timestamp, err := time.Parse(time.RFC3339, "3006-01-02T22:04:05Z") timestamp, err := time.Parse(time.RFC3339, "3006-01-02T22:04:05Z")
assert.NoError(t, err) assert.NoError(t, err)
newAccessToken := "new_some_access_token"
payload := azureOAuthPayload{ payload := azureOAuthPayload{
IDToken: idTokenString, IDToken: idTokenString,
RefreshToken: "new_some_refresh_token", RefreshToken: "new_some_refresh_token",
AccessToken: "new_some_access_token", AccessToken: newAccessToken,
ExpiresOn: timestamp.Unix(), ExpiresOn: timestamp.Unix(),
} }
payloadBytes, err := json.Marshal(payload) payloadBytes, err := json.Marshal(payload)
assert.NoError(t, err) assert.NoError(t, err)
b := testAzureBackend(string(payloadBytes))
refreshToken := "some_refresh_token"
b := testAzureBackend(string(payloadBytes), newAccessToken, refreshToken)
defer b.Close() defer b.Close()
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
expires := time.Now().Add(time.Duration(-1) * time.Hour) expires := time.Now().Add(time.Duration(-1) * time.Hour)
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: refreshToken, IDToken: "some_id_token", ExpiresOn: &expires}
refreshed, err := p.RefreshSession(context.Background(), session) refreshed, err := p.RefreshSession(context.Background(), session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.True(t, refreshed) assert.True(t, refreshed)
assert.NotEqual(t, session, nil) assert.NotEqual(t, session, nil)
assert.Equal(t, "new_some_access_token", session.AccessToken) assert.Equal(t, newAccessToken, session.AccessToken)
assert.Equal(t, "new_some_refresh_token", session.RefreshToken) assert.Equal(t, "new_some_refresh_token", session.RefreshToken)
assert.Equal(t, idTokenString, session.IDToken) assert.Equal(t, idTokenString, session.IDToken)
assert.Equal(t, email, session.Email) assert.Equal(t, email, session.Email)

View File

@ -57,6 +57,8 @@ func NewDigitalOceanProvider(p *ProviderData) *DigitalOceanProvider {
validateURL: digitalOceanDefaultProfileURL, validateURL: digitalOceanDefaultProfileURL,
scope: digitalOceanDefaultScope, scope: digitalOceanDefaultScope,
}) })
p.getAuthorizationHeaderFunc = makeOIDCHeader
return &DigitalOceanProvider{ProviderData: p} return &DigitalOceanProvider{ProviderData: p}
} }

View File

@ -58,6 +58,7 @@ func NewFacebookProvider(p *ProviderData) *FacebookProvider {
validateURL: facebookDefaultProfileURL, validateURL: facebookDefaultProfileURL,
scope: facebookDefaultScope, scope: facebookDefaultScope,
}) })
p.getAuthorizationHeaderFunc = makeOIDCHeader
return &FacebookProvider{ProviderData: p} return &FacebookProvider{ProviderData: p}
} }

View File

@ -65,6 +65,8 @@ func NewLinkedInProvider(p *ProviderData) *LinkedInProvider {
validateURL: linkedinDefaultValidateURL, validateURL: linkedinDefaultValidateURL,
scope: linkedinDefaultScope, scope: linkedinDefaultScope,
}) })
p.getAuthorizationHeaderFunc = makeLinkedInHeader
return &LinkedInProvider{ProviderData: p} return &LinkedInProvider{ProviderData: p}
} }

View File

@ -1,13 +1,5 @@
package providers package providers
import (
"context"
"fmt"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
)
// NextcloudProvider represents an Nextcloud based Identity Provider // NextcloudProvider represents an Nextcloud based Identity Provider
type NextcloudProvider struct { type NextcloudProvider struct {
*ProviderData *ProviderData
@ -20,20 +12,11 @@ const nextCloudProviderName = "Nextcloud"
// NewNextcloudProvider initiates a new NextcloudProvider // NewNextcloudProvider initiates a new NextcloudProvider
func NewNextcloudProvider(p *ProviderData) *NextcloudProvider { func NewNextcloudProvider(p *ProviderData) *NextcloudProvider {
p.ProviderName = nextCloudProviderName p.ProviderName = nextCloudProviderName
p.getAuthorizationHeaderFunc = makeOIDCHeader
if p.EmailClaim == OIDCEmailClaim {
// This implies the email claim has not been overridden, we should set a default
// for this provider
p.EmailClaim = "ocs.data.email"
}
return &NextcloudProvider{ProviderData: p} return &NextcloudProvider{ProviderData: p}
} }
// GetEmailAddress returns the Account email address
func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
json, err := requests.New(p.ValidateURL.String()).
WithContext(ctx).
WithHeaders(makeOIDCHeader(s.AccessToken)).
Do().
UnmarshalJSON()
if err != nil {
return "", fmt.Errorf("error making request: %v", err)
}
email, err := json.Get("ocs").Get("data").Get("email").String()
return email, err
}

View File

@ -1,18 +1,13 @@
package providers package providers
import ( import (
"context"
"net/http"
"net/http/httptest"
"net/url" "net/url"
"testing" "testing"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
const formatJSON = "format=json" const formatJSON = "format=json"
const userPath = "/ocs/v2.php/cloud/user"
func testNextcloudProvider(hostname string) *NextcloudProvider { func testNextcloudProvider(hostname string) *NextcloudProvider {
p := NewNextcloudProvider( p := NewNextcloudProvider(
@ -32,23 +27,6 @@ func testNextcloudProvider(hostname string) *NextcloudProvider {
return p return p
} }
func testNextcloudBackend(payload string) *httptest.Server {
path := userPath
query := formatJSON
return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != path || r.URL.RawQuery != query {
w.WriteHeader(404)
} else if !IsAuthorizedInHeader(r.Header) {
w.WriteHeader(403)
} else {
w.WriteHeader(200)
w.Write([]byte(payload))
}
}))
}
func TestNextcloudProviderDefaults(t *testing.T) { func TestNextcloudProviderDefaults(t *testing.T) {
p := testNextcloudProvider("") p := testNextcloudProvider("")
assert.NotEqual(t, nil, p) assert.NotEqual(t, nil, p)
@ -87,53 +65,3 @@ func TestNextcloudProviderOverrides(t *testing.T) {
assert.Equal(t, "https://example.com/test/ocs/v2.php/cloud/user?"+formatJSON, assert.Equal(t, "https://example.com/test/ocs/v2.php/cloud/user?"+formatJSON,
p.Data().ValidateURL.String()) p.Data().ValidateURL.String())
} }
func TestNextcloudProviderGetEmailAddress(t *testing.T) {
b := testNextcloudBackend("{\"ocs\": {\"data\": { \"email\": \"michael.bland@gsa.gov\"}}}")
defer b.Close()
bURL, _ := url.Parse(b.URL)
p := testNextcloudProvider(bURL.Host)
p.ValidateURL.Path = userPath
p.ValidateURL.RawQuery = formatJSON
session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email)
}
// Note that trying to trigger the "failed building request" case is not
// practical, since the only way it can fail is if the URL fails to parse.
func TestNextcloudProviderGetEmailAddressFailedRequest(t *testing.T) {
b := testNextcloudBackend("unused payload")
defer b.Close()
bURL, _ := url.Parse(b.URL)
p := testNextcloudProvider(bURL.Host)
p.ValidateURL.Path = userPath
p.ValidateURL.RawQuery = formatJSON
// We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as
// JSON to fail.
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
}
func TestNextcloudProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
b := testNextcloudBackend("{\"foo\": \"bar\"}")
defer b.Close()
bURL, _ := url.Parse(b.URL)
p := testNextcloudProvider(bURL.Host)
p.ValidateURL.Path = userPath
p.ValidateURL.RawQuery = formatJSON
session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
}

View File

@ -5,12 +5,10 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/url" "net/url"
"reflect"
"time" "time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
@ -24,6 +22,8 @@ type OIDCProvider struct {
// NewOIDCProvider initiates a new OIDCProvider // NewOIDCProvider initiates a new OIDCProvider
func NewOIDCProvider(p *ProviderData) *OIDCProvider { func NewOIDCProvider(p *ProviderData) *OIDCProvider {
p.ProviderName = "OpenID Connect" p.ProviderName = "OpenID Connect"
p.getAuthorizationHeaderFunc = makeOIDCHeader
return &OIDCProvider{ return &OIDCProvider{
ProviderData: p, ProviderData: p,
SkipNonce: true, SkipNonce: true,
@ -68,21 +68,6 @@ func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (*s
// EnrichSession is called after Redeem to allow providers to enrich session fields // EnrichSession is called after Redeem to allow providers to enrich session fields
// such as User, Email, Groups with provider specific API calls. // such as User, Email, Groups with provider specific API calls.
func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error {
if p.ProfileURL.String() == "" {
if s.Email == "" {
return errors.New("id_token did not contain an email and profileURL is not defined")
}
return nil
}
// Try to get missing emails or groups from a profileURL
if s.Email == "" || s.Groups == nil {
err := p.enrichFromProfileURL(ctx, s)
if err != nil {
logger.Errorf("Warning: Profile URL request failed: %v", err)
}
}
// If a mandatory email wasn't set, error at this point. // If a mandatory email wasn't set, error at this point.
if s.Email == "" { if s.Email == "" {
return errors.New("neither the id_token nor the profileURL set an email") return errors.New("neither the id_token nor the profileURL set an email")
@ -90,42 +75,9 @@ func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionSta
return nil return nil
} }
// enrichFromProfileURL enriches a session's Email & Groups via the JSON response of
// an OIDC profile URL
func (p *OIDCProvider) enrichFromProfileURL(ctx context.Context, s *sessions.SessionState) error {
respJSON, err := requests.New(p.ProfileURL.String()).
WithContext(ctx).
WithHeaders(makeOIDCHeader(s.AccessToken)).
Do().
UnmarshalJSON()
if err != nil {
return err
}
email, err := respJSON.Get(p.EmailClaim).String()
if err == nil && s.Email == "" {
s.Email = email
}
if len(s.Groups) > 0 {
return nil
}
for _, group := range coerceArray(respJSON, p.GroupsClaim) {
formatted, err := formatGroup(group)
if err != nil {
logger.Errorf("Warning: unable to format group of type %s with error %s",
reflect.TypeOf(group), err)
continue
}
s.Groups = append(s.Groups, formatted)
}
return nil
}
// ValidateSession checks that the session's IDToken is still valid // ValidateSession checks that the session's IDToken is still valid
func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
idToken, err := p.Verifier.Verify(ctx, s.IDToken) _, err := p.Verifier.Verify(ctx, s.IDToken)
if err != nil { if err != nil {
logger.Errorf("id_token verification failed: %v", err) logger.Errorf("id_token verification failed: %v", err)
return false return false
@ -134,7 +86,7 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS
if p.SkipNonce { if p.SkipNonce {
return true return true
} }
err = p.checkNonce(s, idToken) err = p.checkNonce(s)
if err != nil { if err != nil {
logger.Errorf("nonce verification failed: %v", err) logger.Errorf("nonce verification failed: %v", err)
return false return false
@ -212,7 +164,7 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string)
return nil, err return nil, err
} }
ss, err := p.buildSessionFromClaims(idToken) ss, err := p.buildSessionFromClaims(token, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -235,7 +187,7 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string)
// createSession takes an oauth2.Token and creates a SessionState from it. // createSession takes an oauth2.Token and creates a SessionState from it.
// It alters behavior if called from Redeem vs Refresh // It alters behavior if called from Redeem vs Refresh
func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, refresh bool) (*sessions.SessionState, error) { func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, refresh bool) (*sessions.SessionState, error) {
idToken, err := p.verifyIDToken(ctx, token) _, err := p.verifyIDToken(ctx, token)
if err != nil { if err != nil {
switch err { switch err {
case ErrMissingIDToken: case ErrMissingIDToken:
@ -248,14 +200,15 @@ func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, r
} }
} }
ss, err := p.buildSessionFromClaims(idToken) rawIDToken := getIDToken(token)
ss, err := p.buildSessionFromClaims(rawIDToken, token.AccessToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ss.AccessToken = token.AccessToken ss.AccessToken = token.AccessToken
ss.RefreshToken = token.RefreshToken ss.RefreshToken = token.RefreshToken
ss.IDToken = getIDToken(token) ss.IDToken = rawIDToken
ss.CreatedAtNow() ss.CreatedAtNow()
ss.SetExpiresOn(token.Expiry) ss.SetExpiresOn(token.Expiry)

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -54,6 +53,7 @@ func newOIDCProvider(serverURL *url.URL) *OIDCProvider {
Scope: "openid profile offline_access", Scope: "openid profile offline_access",
EmailClaim: "email", EmailClaim: "email",
GroupsClaim: "groups", GroupsClaim: "groups",
UserClaim: "sub",
Verifier: internaloidc.NewVerifier(oidc.NewVerifier( Verifier: internaloidc.NewVerifier(oidc.NewVerifier(
oidcIssuer, oidcIssuer,
mockJWKS{}, mockJWKS{},
@ -142,333 +142,6 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) {
assert.Equal(t, defaultIDToken.Phone, session.Email) assert.Equal(t, defaultIDToken.Phone, session.Email)
} }
func TestOIDCProvider_EnrichSession(t *testing.T) {
testCases := map[string]struct {
ExistingSession *sessions.SessionState
EmailClaim string
GroupsClaim string
ProfileJSON map[string]interface{}
ExpectedError error
ExpectedSession *sessions.SessionState
}{
"Already Populated": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"already", "populated"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"groups": []string{"new", "thing"},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"already", "populated"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Email": {
ExistingSession: &sessions.SessionState{
User: "missing.email",
Groups: []string{"already", "populated"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "found@email.com",
"groups": []string{"new", "thing"},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "missing.email",
Email: "found@email.com",
Groups: []string{"already", "populated"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Email Only in Profile URL": {
ExistingSession: &sessions.SessionState{
User: "missing.email",
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "found@email.com",
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "missing.email",
Email: "found@email.com",
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Email with Custom Claim": {
ExistingSession: &sessions.SessionState{
User: "missing.email",
Groups: []string{"already", "populated"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "weird",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"weird": "weird@claim.com",
"groups": []string{"new", "thing"},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "missing.email",
Email: "weird@claim.com",
Groups: []string{"already", "populated"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Email not in Profile URL": {
ExistingSession: &sessions.SessionState{
User: "missing.email",
Groups: []string{"already", "populated"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"groups": []string{"new", "thing"},
},
ExpectedError: errors.New("neither the id_token nor the profileURL set an email"),
ExpectedSession: &sessions.SessionState{
User: "missing.email",
Groups: []string{"already", "populated"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Groups": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: nil,
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"groups": []string{"new", "thing"},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"new", "thing"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Groups with Complex Groups in Profile URL": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: nil,
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"groups": []map[string]interface{}{
{
"groupId": "Admin Group Id",
"roles": []string{"Admin"},
},
},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Groups with Singleton Complex Group in Profile URL": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: nil,
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"groups": map[string]interface{}{
"groupId": "Admin Group Id",
"roles": []string{"Admin"},
},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Empty Groups Claims": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"groups": []string{"new", "thing"},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Groups with Custom Claim": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: nil,
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "roles",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"roles": []string{"new", "thing", "roles"},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"new", "thing", "roles"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Groups String Profile URL Response": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: nil,
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"groups": "singleton",
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"singleton"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Groups in both Claims and Profile URL": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
}
for testName, tc := range testCases {
t.Run(testName, func(t *testing.T) {
jsonResp, err := json.Marshal(tc.ProfileJSON)
assert.NoError(t, err)
server, provider := newTestOIDCSetup(jsonResp)
provider.ProfileURL, err = url.Parse(server.URL)
assert.NoError(t, err)
provider.EmailClaim = tc.EmailClaim
provider.GroupsClaim = tc.GroupsClaim
defer server.Close()
err = provider.EnrichSession(context.Background(), tc.ExistingSession)
assert.Equal(t, tc.ExpectedError, err)
assert.Equal(t, *tc.ExpectedSession, *tc.ExistingSession)
})
}
}
func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) { func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
idToken, _ := newSignedTestIDToken(defaultIDToken) idToken, _ := newSignedTestIDToken(defaultIDToken)
@ -565,11 +238,15 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) {
ExpectedGroups: []string{"test:c", "test:d"}, ExpectedGroups: []string{"test:c", "test:d"},
}, },
"Complex Groups Claim": { "Complex Groups Claim": {
IDToken: complexGroupsIDToken, IDToken: complexGroupsIDToken,
GroupsClaim: "groups", GroupsClaim: "groups",
ExpectedUser: "123456789", ExpectedUser: "123456789",
ExpectedEmail: "complex@claims.com", ExpectedEmail: "complex@claims.com",
ExpectedGroups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"}, ExpectedGroups: []string{
"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}",
"12345",
"Just::A::String",
},
}, },
} }
for testName, tc := range testCases { for testName, tc := range testCases {

View File

@ -5,20 +5,23 @@ import (
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http"
"net/url" "net/url"
"reflect"
"strings" "strings"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/oidc" internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/oidc"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/providers/util"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
const ( const (
OIDCEmailClaim = "email" OIDCEmailClaim = "email"
OIDCGroupsClaim = "groups" OIDCGroupsClaim = "groups"
// This is not exported as it's not currently user configurable
oidcUserClaim = "sub"
) )
var OIDCAudienceClaims = []string{"aud"} var OIDCAudienceClaims = []string{"aud"}
@ -52,6 +55,8 @@ type ProviderData struct {
// Universal Group authorization data structure // Universal Group authorization data structure
// any provider can set to consume // any provider can set to consume
AllowedGroups map[string]struct{} AllowedGroups map[string]struct{}
getAuthorizationHeaderFunc func(string) http.Header
} }
// Data returns the ProviderData // Data returns the ProviderData
@ -99,6 +104,10 @@ func (p *ProviderData) setProviderDefaults(defaults providerDefaults) {
if p.Scope == "" { if p.Scope == "" {
p.Scope = defaults.scope p.Scope = defaults.scope
} }
if p.UserClaim == "" {
p.UserClaim = oidcUserClaim
}
} }
// defaultURL will set return a default value if the given value is not set. // defaultURL will set return a default value if the given value is not set.
@ -120,17 +129,6 @@ func defaultURL(u *url.URL, d *url.URL) *url.URL {
// OIDC compliant // OIDC compliant
// **************************************************************************** // ****************************************************************************
// OIDCClaims is a struct to unmarshal the OIDC claims from an ID Token payload
type OIDCClaims struct {
Subject string `json:"sub"`
Email string `json:"-"`
Groups []string `json:"-"`
Verified *bool `json:"email_verified"`
Nonce string `json:"nonce"`
raw map[string]interface{}
}
func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) {
rawIDToken := getIDToken(token) rawIDToken := getIDToken(token)
if strings.TrimSpace(rawIDToken) == "" { if strings.TrimSpace(rawIDToken) == "" {
@ -144,110 +142,80 @@ func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) (
// buildSessionFromClaims uses IDToken claims to populate a fresh SessionState // buildSessionFromClaims uses IDToken claims to populate a fresh SessionState
// with non-Token related fields. // with non-Token related fields.
func (p *ProviderData) buildSessionFromClaims(idToken *oidc.IDToken) (*sessions.SessionState, error) { func (p *ProviderData) buildSessionFromClaims(rawIDToken, accessToken string) (*sessions.SessionState, error) {
ss := &sessions.SessionState{} ss := &sessions.SessionState{}
if idToken == nil { if rawIDToken == "" {
return ss, nil return ss, nil
} }
claims, err := p.getClaims(idToken) extractor, err := p.getClaimExtractor(rawIDToken, accessToken)
if err != nil { if err != nil {
return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err) return nil, err
} }
ss.User = claims.Subject // Use a slice of a struct (vs map) here in case the same claim is used twice
ss.Email = claims.Email for _, c := range []struct {
ss.Groups = claims.Groups claim string
dst interface{}
// Allow specialized providers that embed OIDCProvider to control the User }{
// claim. Not exposed as a configuration flag to generic OIDC provider {p.UserClaim, &ss.User},
// users (yet). {p.EmailClaim, &ss.Email},
if p.UserClaim != "" { {p.GroupsClaim, &ss.Groups},
user, ok := claims.raw[p.UserClaim].(string) // TODO (@NickMeves) Deprecate for dynamic claim to session mapping
if !ok { {"preferred_username", &ss.PreferredUsername},
return nil, fmt.Errorf("unable to extract custom UserClaim (%s)", p.UserClaim) } {
if _, err := extractor.GetClaimInto(c.claim, c.dst); err != nil {
return nil, err
} }
ss.User = user
}
// TODO (@NickMeves) Deprecate for dynamic claim to session mapping
if pref, ok := claims.raw["preferred_username"].(string); ok {
ss.PreferredUsername = pref
} }
// `email_verified` must be present and explicitly set to `false` to be // `email_verified` must be present and explicitly set to `false` to be
// considered unverified. // considered unverified.
verifyEmail := (p.EmailClaim == OIDCEmailClaim) && !p.AllowUnverifiedEmail verifyEmail := (p.EmailClaim == OIDCEmailClaim) && !p.AllowUnverifiedEmail
if verifyEmail && claims.Verified != nil && !*claims.Verified {
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) var verified bool
exists, err := extractor.GetClaimInto("email_verified", &verified)
if err != nil {
return nil, err
}
if verifyEmail && exists && !verified {
return nil, fmt.Errorf("email in id_token (%s) isn't verified", ss.Email)
} }
return ss, nil return ss, nil
} }
// getClaims extracts IDToken claims into an OIDCClaims func (p *ProviderData) getClaimExtractor(rawIDToken, accessToken string) (util.ClaimExtractor, error) {
func (p *ProviderData) getClaims(idToken *oidc.IDToken) (*OIDCClaims, error) { extractor, err := util.NewClaimExtractor(context.TODO(), rawIDToken, p.ProfileURL, p.getAuthorizationHeader(accessToken))
claims := &OIDCClaims{} if err != nil {
return nil, fmt.Errorf("could not initialise claim extractor: %v", err)
// Extract default claims.
if err := idToken.Claims(&claims); err != nil {
return nil, fmt.Errorf("failed to parse default id_token claims: %v", err)
}
// Extract custom claims.
if err := idToken.Claims(&claims.raw); err != nil {
return nil, fmt.Errorf("failed to parse all id_token claims: %v", err)
} }
email := claims.raw[p.EmailClaim] return extractor, nil
if email != nil {
claims.Email = fmt.Sprint(email)
}
claims.Groups = p.extractGroups(claims.raw)
return claims, nil
} }
// checkNonce compares the session's nonce with the IDToken's nonce claim // checkNonce compares the session's nonce with the IDToken's nonce claim
func (p *ProviderData) checkNonce(s *sessions.SessionState, idToken *oidc.IDToken) error { func (p *ProviderData) checkNonce(s *sessions.SessionState) error {
claims, err := p.getClaims(idToken) extractor, err := p.getClaimExtractor(s.IDToken, "")
if err != nil { if err != nil {
return fmt.Errorf("id_token claims extraction failed: %v", err) return fmt.Errorf("id_token claims extraction failed: %v", err)
} }
if !s.CheckNonce(claims.Nonce) { var nonce string
if _, err := extractor.GetClaimInto("nonce", &nonce); err != nil {
return fmt.Errorf("could not extract nonce from ID Token: %v", err)
}
if !s.CheckNonce(nonce) {
return errors.New("id_token nonce claim does not match the session nonce") return errors.New("id_token nonce claim does not match the session nonce")
} }
return nil return nil
} }
// extractGroups extracts groups from a claim to a list in a type safe manner. func (p *ProviderData) getAuthorizationHeader(accessToken string) http.Header {
// If the claim isn't present, `nil` is returned. If the groups claim is if p.getAuthorizationHeaderFunc != nil && accessToken != "" {
// present but empty, `[]string{}` is returned. return p.getAuthorizationHeaderFunc(accessToken)
func (p *ProviderData) extractGroups(claims map[string]interface{}) []string {
rawClaim, ok := claims[p.GroupsClaim]
if !ok {
return nil
} }
return nil
// Handle traditional list-based groups as well as non-standard singleton
// based groups. Both variants support complex objects if needed.
var claimGroups []interface{}
switch raw := rawClaim.(type) {
case []interface{}:
claimGroups = raw
case interface{}:
claimGroups = []interface{}{raw}
}
groups := []string{}
for _, rawGroup := range claimGroups {
formattedGroup, err := formatGroup(rawGroup)
if err != nil {
logger.Errorf("Warning: unable to format group of type %s with error %s",
reflect.TypeOf(rawGroup), err)
continue
}
groups = append(groups, formattedGroup)
}
return groups
} }

View File

@ -60,16 +60,30 @@ var (
StandardClaims: standardClaims, StandardClaims: standardClaims,
} }
numericGroupsIDToken = idTokenClaims{
Name: "Jane Dobbs",
Email: "janed@me.com",
Phone: "+4798765432",
Picture: "http://mugbook.com/janed/me.jpg",
Groups: []interface{}{1, 2, 3},
Roles: []string{"test:c", "test:d"},
Verified: &verified,
Nonce: encryption.HashNonce([]byte(oidcNonce)),
StandardClaims: standardClaims,
}
complexGroupsIDToken = idTokenClaims{ complexGroupsIDToken = idTokenClaims{
Name: "Complex Claim", Name: "Complex Claim",
Email: "complex@claims.com", Email: "complex@claims.com",
Phone: "+5439871234", Phone: "+5439871234",
Picture: "http://mugbook.com/complex/claims.jpg", Picture: "http://mugbook.com/complex/claims.jpg",
Groups: []map[string]interface{}{ Groups: []interface{}{
{ map[string]interface{}{
"groupId": "Admin Group Id", "groupId": "Admin Group Id",
"roles": []string{"Admin"}, "roles": []string{"Admin"},
}, },
12345,
"Just::A::String",
}, },
Roles: []string{"test:simple", "test:roles"}, Roles: []string{"test:simple", "test:roles"},
Verified: &verified, Verified: &verified,
@ -228,6 +242,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
AllowUnverified: false, AllowUnverified: false,
EmailClaim: "email", EmailClaim: "email",
GroupsClaim: "groups", GroupsClaim: "groups",
UserClaim: "sub",
ExpectedSession: &sessions.SessionState{ ExpectedSession: &sessions.SessionState{
User: "123456789", User: "123456789",
Email: "janed@me.com", Email: "janed@me.com",
@ -247,6 +262,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
AllowUnverified: true, AllowUnverified: true,
EmailClaim: "email", EmailClaim: "email",
GroupsClaim: "groups", GroupsClaim: "groups",
UserClaim: "sub",
ExpectedSession: &sessions.SessionState{ ExpectedSession: &sessions.SessionState{
User: "123456789", User: "123456789",
Email: "unverified@email.com", Email: "unverified@email.com",
@ -259,10 +275,15 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
AllowUnverified: true, AllowUnverified: true,
EmailClaim: "email", EmailClaim: "email",
GroupsClaim: "groups", GroupsClaim: "groups",
UserClaim: "sub",
ExpectedSession: &sessions.SessionState{ ExpectedSession: &sessions.SessionState{
User: "123456789", User: "123456789",
Email: "complex@claims.com", Email: "complex@claims.com",
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"}, Groups: []string{
"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}",
"12345",
"Just::A::String",
},
PreferredUsername: "Complex Claim", PreferredUsername: "Complex Claim",
}, },
}, },
@ -279,19 +300,25 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
PreferredUsername: "Jane Dobbs", PreferredUsername: "Jane Dobbs",
}, },
}, },
"User Claim Invalid": { "User Claim switched to non string": {
IDToken: defaultIDToken, IDToken: defaultIDToken,
AllowUnverified: true, AllowUnverified: true,
UserClaim: "groups", UserClaim: "roles",
EmailClaim: "email", EmailClaim: "email",
GroupsClaim: "groups", GroupsClaim: "groups",
ExpectedError: errors.New("unable to extract custom UserClaim (groups)"), ExpectedSession: &sessions.SessionState{
User: "[\"test:c\",\"test:d\"]",
Email: "janed@me.com",
Groups: []string{"test:a", "test:b"},
PreferredUsername: "Jane Dobbs",
},
}, },
"Email Claim Switched": { "Email Claim Switched": {
IDToken: unverifiedIDToken, IDToken: unverifiedIDToken,
AllowUnverified: true, AllowUnverified: true,
EmailClaim: "phone_number", EmailClaim: "phone_number",
GroupsClaim: "groups", GroupsClaim: "groups",
UserClaim: "sub",
ExpectedSession: &sessions.SessionState{ ExpectedSession: &sessions.SessionState{
User: "123456789", User: "123456789",
Email: "+4025205729", Email: "+4025205729",
@ -304,9 +331,10 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
AllowUnverified: true, AllowUnverified: true,
EmailClaim: "roles", EmailClaim: "roles",
GroupsClaim: "groups", GroupsClaim: "groups",
UserClaim: "sub",
ExpectedSession: &sessions.SessionState{ ExpectedSession: &sessions.SessionState{
User: "123456789", User: "123456789",
Email: "[test:c test:d]", Email: "[\"test:c\",\"test:d\"]",
Groups: []string{"test:a", "test:b"}, Groups: []string{"test:a", "test:b"},
PreferredUsername: "Mystery Man", PreferredUsername: "Mystery Man",
}, },
@ -316,6 +344,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
AllowUnverified: true, AllowUnverified: true,
EmailClaim: "aksjdfhjksadh", EmailClaim: "aksjdfhjksadh",
GroupsClaim: "groups", GroupsClaim: "groups",
UserClaim: "sub",
ExpectedSession: &sessions.SessionState{ ExpectedSession: &sessions.SessionState{
User: "123456789", User: "123456789",
Email: "", Email: "",
@ -328,6 +357,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
AllowUnverified: false, AllowUnverified: false,
EmailClaim: "email", EmailClaim: "email",
GroupsClaim: "roles", GroupsClaim: "roles",
UserClaim: "sub",
ExpectedSession: &sessions.SessionState{ ExpectedSession: &sessions.SessionState{
User: "123456789", User: "123456789",
Email: "janed@me.com", Email: "janed@me.com",
@ -340,6 +370,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
AllowUnverified: false, AllowUnverified: false,
EmailClaim: "email", EmailClaim: "email",
GroupsClaim: "alskdjfsalkdjf", GroupsClaim: "alskdjfsalkdjf",
UserClaim: "sub",
ExpectedSession: &sessions.SessionState{ ExpectedSession: &sessions.SessionState{
User: "123456789", User: "123456789",
Email: "janed@me.com", Email: "janed@me.com",
@ -347,6 +378,32 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
PreferredUsername: "Jane Dobbs", PreferredUsername: "Jane Dobbs",
}, },
}, },
"Groups Claim Numeric values": {
IDToken: numericGroupsIDToken,
AllowUnverified: false,
EmailClaim: "email",
GroupsClaim: "groups",
UserClaim: "sub",
ExpectedSession: &sessions.SessionState{
User: "123456789",
Email: "janed@me.com",
Groups: []string{"1", "2", "3"},
PreferredUsername: "Jane Dobbs",
},
},
"Groups Claim string values": {
IDToken: defaultIDToken,
AllowUnverified: false,
EmailClaim: "email",
GroupsClaim: "email",
UserClaim: "sub",
ExpectedSession: &sessions.SessionState{
User: "123456789",
Email: "janed@me.com",
Groups: []string{"janed@me.com"},
PreferredUsername: "Jane Dobbs",
},
},
} }
for testName, tc := range testCases { for testName, tc := range testCases {
t.Run(testName, func(t *testing.T) { t.Run(testName, func(t *testing.T) {
@ -371,10 +428,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
rawIDToken, err := newSignedTestIDToken(tc.IDToken) rawIDToken, err := newSignedTestIDToken(tc.IDToken)
g.Expect(err).ToNot(HaveOccurred()) g.Expect(err).ToNot(HaveOccurred())
idToken, err := provider.Verifier.Verify(context.Background(), rawIDToken) ss, err := provider.buildSessionFromClaims(rawIDToken, "")
g.Expect(err).ToNot(HaveOccurred())
ss, err := provider.buildSessionFromClaims(idToken)
if err != nil { if err != nil {
g.Expect(err).To(Equal(tc.ExpectedError)) g.Expect(err).To(Equal(tc.ExpectedError))
} }
@ -418,6 +472,12 @@ func TestProviderData_checkNonce(t *testing.T) {
t.Run(testName, func(t *testing.T) { t.Run(testName, func(t *testing.T) {
g := NewWithT(t) g := NewWithT(t)
// Ensure that the ID token in the session is valid (signed and contains a nonce)
// as the nonce claim is extracted to compare with the session nonce
rawIDToken, err := newSignedTestIDToken(tc.IDToken)
g.Expect(err).ToNot(HaveOccurred())
tc.Session.IDToken = rawIDToken
verificationOptions := &internaloidc.IDTokenVerificationOptions{ verificationOptions := &internaloidc.IDTokenVerificationOptions{
AudienceClaims: []string{"aud"}, AudienceClaims: []string{"aud"},
ClientID: oidcClientID, ClientID: oidcClientID,
@ -430,14 +490,7 @@ func TestProviderData_checkNonce(t *testing.T) {
), verificationOptions), ), verificationOptions),
} }
rawIDToken, err := newSignedTestIDToken(tc.IDToken) if err := provider.checkNonce(tc.Session); err != nil {
g.Expect(err).ToNot(HaveOccurred())
idToken, err := provider.Verifier.Verify(context.Background(), rawIDToken)
g.Expect(err).ToNot(HaveOccurred())
err = provider.checkNonce(tc.Session, idToken)
if err != nil {
g.Expect(err).To(Equal(tc.ExpectedError)) g.Expect(err).To(Equal(tc.ExpectedError))
} else { } else {
g.Expect(err).ToNot(HaveOccurred()) g.Expect(err).ToNot(HaveOccurred())
@ -445,95 +498,3 @@ func TestProviderData_checkNonce(t *testing.T) {
}) })
} }
} }
func TestProviderData_extractGroups(t *testing.T) {
testCases := map[string]struct {
Claims map[string]interface{}
GroupsClaim string
ExpectedGroups []string
}{
"Standard String Groups": {
Claims: map[string]interface{}{
"email": "this@does.not.matter.com",
"groups": []interface{}{"three", "string", "groups"},
},
GroupsClaim: "groups",
ExpectedGroups: []string{"three", "string", "groups"},
},
"Different Claim Name": {
Claims: map[string]interface{}{
"email": "this@does.not.matter.com",
"roles": []interface{}{"three", "string", "roles"},
},
GroupsClaim: "roles",
ExpectedGroups: []string{"three", "string", "roles"},
},
"Numeric Groups": {
Claims: map[string]interface{}{
"email": "this@does.not.matter.com",
"groups": []interface{}{1, 2, 3},
},
GroupsClaim: "groups",
ExpectedGroups: []string{"1", "2", "3"},
},
"Complex Groups": {
Claims: map[string]interface{}{
"email": "this@does.not.matter.com",
"groups": []interface{}{
map[string]interface{}{
"groupId": "Admin Group Id",
"roles": []string{"Admin"},
},
12345,
"Just::A::String",
},
},
GroupsClaim: "groups",
ExpectedGroups: []string{
"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}",
"12345",
"Just::A::String",
},
},
"Missing Groups Claim Returns Nil": {
Claims: map[string]interface{}{
"email": "this@does.not.matter.com",
},
GroupsClaim: "groups",
ExpectedGroups: nil,
},
"Non List Groups": {
Claims: map[string]interface{}{
"email": "this@does.not.matter.com",
"groups": "singleton",
},
GroupsClaim: "groups",
ExpectedGroups: []string{"singleton"},
},
}
for testName, tc := range testCases {
t.Run(testName, func(t *testing.T) {
g := NewWithT(t)
verificationOptions := &internaloidc.IDTokenVerificationOptions{
AudienceClaims: []string{"aud"},
ClientID: oidcClientID,
}
provider := &ProviderData{
Verifier: internaloidc.NewVerifier(oidc.NewVerifier(
oidcIssuer,
mockJWKS{},
&oidc.Config{ClientID: oidcClientID},
), verificationOptions),
}
provider.GroupsClaim = tc.GroupsClaim
groups := provider.extractGroups(tc.Claims)
if tc.ExpectedGroups != nil {
g.Expect(groups).To(Equal(tc.ExpectedGroups))
} else {
g.Expect(groups).To(BeNil())
}
})
}
}

View File

@ -6,7 +6,6 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"github.com/bitly/go-simplejson"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
@ -83,18 +82,3 @@ func formatGroup(rawGroup interface{}) (string, error) {
} }
return string(jsonGroup), nil return string(jsonGroup), nil
} }
// coerceArray extracts a field from simplejson.Json that might be a
// singleton or a list and coerces it into a list.
func coerceArray(sj *simplejson.Json, key string) []interface{} {
array, err := sj.Get(key).Array()
if err == nil {
return array
}
single := sj.Get(key).Interface()
if single == nil {
return nil
}
return []interface{}{single}
}