diff --git a/providers/oidc.go b/providers/oidc.go index acf48f55..8f2f02b9 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -157,7 +157,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok newSession = &sessions.SessionState{} } else { var err error - newSession, err = p.createSessionStateInternal(ctx, token.Extra("id_token").(string), idToken, token, false) + newSession, err = p.createSessionStateInternal(ctx, idToken, token) if err != nil { return nil, err } @@ -172,7 +172,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok } func (p *OIDCProvider) CreateSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) { - newSession, err := p.createSessionStateInternal(ctx, rawIDToken, idToken, nil, true) + newSession, err := p.createSessionStateInternal(ctx, idToken, nil) if err != nil { return nil, err } @@ -185,24 +185,22 @@ func (p *OIDCProvider) CreateSessionStateFromBearerToken(ctx context.Context, ra return newSession, nil } -func (p *OIDCProvider) createSessionStateInternal(ctx context.Context, rawIDToken string, idToken *oidc.IDToken, token *oauth2.Token, bearer bool) (*sessions.SessionState, error) { +func (p *OIDCProvider) createSessionStateInternal(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*sessions.SessionState, error) { newSession := &sessions.SessionState{} if idToken == nil { return newSession, nil } - accessToken := "" - if token != nil { - accessToken = token.AccessToken - } - claims, err := p.findClaimsFromIDToken(ctx, idToken, accessToken, p.ProfileURL.String(), bearer) + claims, err := p.findClaimsFromIDToken(ctx, idToken, token) if err != nil { return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err) } - newSession.IDToken = rawIDToken + if token != nil { + newSession.IDToken = token.Extra("id_token").(string) + } newSession.Email = claims.UserID // TODO Rename SessionState.Email to .UserID in the near future @@ -230,7 +228,7 @@ func getOIDCHeader(accessToken string) http.Header { return header } -func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.IDToken, accessToken string, profileURL string, bearer bool) (*OIDCClaims, error) { +func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*OIDCClaims, error) { claims := &OIDCClaims{} // Extract default claims. if err := idToken.Claims(&claims); err != nil { @@ -248,11 +246,15 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc. // userID claim was not present or was empty in the ID Token if claims.UserID == "" { - if profileURL == "" { - if bearer { - claims.UserID = claims.Subject - return claims, nil - } + // BearerToken case, allow empty UserID + // ProfileURL checks below won't work since we don't have an access token + if token == nil { + claims.UserID = claims.Subject + return claims, nil + } + + profileURL := p.ProfileURL.String() + if profileURL == "" || token.AccessToken == "" { return nil, fmt.Errorf("id_token did not contain user ID claim (%q)", p.UserIDClaim) } @@ -261,7 +263,7 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc. // Make a query to the userinfo endpoint, and attempt to locate the email from there. respJSON, err := requests.New(profileURL). WithContext(ctx). - WithHeaders(getOIDCHeader(accessToken)). + WithHeaders(getOIDCHeader(token.AccessToken)). Do(). UnmarshalJSON() if err != nil { diff --git a/providers/oidc_test.go b/providers/oidc_test.go index b0268138..9e96752d 100644 --- a/providers/oidc_test.go +++ b/providers/oidc_test.go @@ -274,23 +274,18 @@ func TestCreateSessionStateFromBearerToken(t *testing.T) { testCases := map[string]struct { IDToken idTokenClaims - ProfileURL bool + ExpectedUser string ExpectedEmail string }{ "Default IDToken": { IDToken: defaultIDToken, - ProfileURL: true, - ExpectedEmail: profileURLEmail, + ExpectedUser: defaultIDToken.Subject, + ExpectedEmail: defaultIDToken.Email, }, - "Minimal IDToken with no OIDC Profile URL": { + "Minimal IDToken with no email claim": { IDToken: minimalIDToken, - ProfileURL: false, - ExpectedEmail: "", - }, - "Minimal IDToken with OIDC Profile URL": { - IDToken: minimalIDToken, - ProfileURL: true, - ExpectedEmail: profileURLEmail, + ExpectedUser: minimalIDToken.Subject, + ExpectedEmail: minimalIDToken.Subject, }, } for testName, tc := range testCases { @@ -298,9 +293,6 @@ func TestCreateSessionStateFromBearerToken(t *testing.T) { jsonResp := []byte(fmt.Sprintf(`{"email":"%s"}`, profileURLEmail)) server, provider := newTestSetup(jsonResp) defer server.Close() - if !tc.ProfileURL { - provider.ProfileURL = &url.URL{} - } rawIDToken, err := newSignedTestIDToken(tc.IDToken) assert.NoError(t, err) @@ -315,13 +307,8 @@ func TestCreateSessionStateFromBearerToken(t *testing.T) { ss, err := provider.CreateSessionStateFromBearerToken(context.Background(), rawIDToken, idToken) assert.NoError(t, err) - if tc.ExpectedEmail != "" { - assert.Equal(t, tc.ExpectedEmail, ss.Email) - assert.NotEqual(t, ss.Email, ss.User) - } else { - assert.Equal(t, tc.IDToken.Subject, ss.Email) - assert.Equal(t, ss.Email, ss.User) - } + assert.Equal(t, tc.ExpectedUser, ss.User) + assert.Equal(t, tc.ExpectedEmail, ss.Email) assert.Equal(t, rawIDToken, ss.IDToken) assert.Equal(t, rawIDToken, ss.AccessToken) assert.Equal(t, "", ss.RefreshToken)