mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-06-04 23:37:29 +02:00
Refactor OIDC to EnrichSession
This commit is contained in:
parent
4fda907830
commit
a1877434b2
@ -274,7 +274,7 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
|
|||||||
p.SetRepository(o.BitbucketRepository)
|
p.SetRepository(o.BitbucketRepository)
|
||||||
case *providers.OIDCProvider:
|
case *providers.OIDCProvider:
|
||||||
p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail
|
p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail
|
||||||
p.UserIDClaim = o.UserIDClaim
|
p.EmailClaim = o.UserIDClaim
|
||||||
p.GroupsClaim = o.OIDCGroupsClaim
|
p.GroupsClaim = o.OIDCGroupsClaim
|
||||||
if p.Verifier == nil {
|
if p.Verifier == nil {
|
||||||
msgs = append(msgs, "oidc provider requires an oidc issuer URL")
|
msgs = append(msgs, "oidc provider requires an oidc issuer URL")
|
||||||
|
@ -2,18 +2,17 @@ package providers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
oidc "github.com/coreos/go-oidc"
|
"github.com/coreos/go-oidc"
|
||||||
"golang.org/x/oauth2"
|
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
const emailClaim = "email"
|
const emailClaim = "email"
|
||||||
@ -23,7 +22,7 @@ type OIDCProvider struct {
|
|||||||
*ProviderData
|
*ProviderData
|
||||||
|
|
||||||
AllowUnverifiedEmail bool
|
AllowUnverifiedEmail bool
|
||||||
UserIDClaim string
|
EmailClaim string
|
||||||
GroupsClaim string
|
GroupsClaim string
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -36,10 +35,10 @@ func NewOIDCProvider(p *ProviderData) *OIDCProvider {
|
|||||||
var _ Provider = (*OIDCProvider)(nil)
|
var _ Provider = (*OIDCProvider)(nil)
|
||||||
|
|
||||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||||
func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
|
func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
|
||||||
clientSecret, err := p.GetClientSecret()
|
clientSecret, err := p.GetClientSecret()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
c := oauth2.Config{
|
c := oauth2.Config{
|
||||||
@ -52,23 +51,74 @@ func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (s
|
|||||||
}
|
}
|
||||||
token, err := c.Exchange(ctx, code)
|
token, err := c.Exchange(ctx, code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("token exchange: %v", err)
|
return nil, fmt.Errorf("token exchange failure: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// in the initial exchange the id token is mandatory
|
return p.createSession(ctx, token, false)
|
||||||
idToken, err := p.findVerifiedIDToken(ctx, token)
|
}
|
||||||
|
|
||||||
|
// EnrichSessionState is called after Redeem to allow providers to enrich session fields
|
||||||
|
// such as User, Email, Groups with provider specific API calls.
|
||||||
|
func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error {
|
||||||
|
if p.ProfileURL.String() == "" {
|
||||||
|
if s.Email == "" {
|
||||||
|
return errors.New("id_token did not contain an email and profileURL is not defined")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to get missing emails or groups from a profileURL
|
||||||
|
if s.Email == "" || len(s.Groups) == 0 {
|
||||||
|
err := p.callProfileURL(ctx, s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not verify id_token: %v", err)
|
logger.Errorf("Warning: Profile URL request failed: %v", err)
|
||||||
} else if idToken == nil {
|
}
|
||||||
return nil, fmt.Errorf("token response did not contain an id_token")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s, err = p.createSessionState(ctx, token, idToken)
|
// If a mandatory email wasn't set, error at this point.
|
||||||
|
if s.Email == "" {
|
||||||
|
return errors.New("neither the id_token nor the profileURL set an email")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// callProfileURL enriches a session's Email & Groups via the JSON response of
|
||||||
|
// an OIDC profile URL
|
||||||
|
func (p *OIDCProvider) callProfileURL(ctx context.Context, s *sessions.SessionState) error {
|
||||||
|
respJSON, err := requests.New(p.ProfileURL.String()).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithHeaders(makeOIDCHeader(s.AccessToken)).
|
||||||
|
Do().
|
||||||
|
UnmarshalJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to update session: %v", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
email, err := respJSON.Get(p.EmailClaim).String()
|
||||||
|
if err == nil && s.Email == "" {
|
||||||
|
s.Email = email
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle array & singleton groups cases
|
||||||
|
if len(s.Groups) == 0 {
|
||||||
|
groups, err := respJSON.Get(p.GroupsClaim).StringArray()
|
||||||
|
if err == nil {
|
||||||
|
s.Groups = groups
|
||||||
|
} else {
|
||||||
|
group, err := respJSON.Get(p.GroupsClaim).String()
|
||||||
|
if err == nil {
|
||||||
|
s.Groups = []string{group}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateSessionState checks that the session's IDToken is still valid
|
||||||
|
func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
|
||||||
|
_, err := p.Verifier.Verify(ctx, s.IDToken)
|
||||||
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||||
@ -83,14 +133,16 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S
|
|||||||
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
|
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("refreshed access token %s (expired on %s)\n", s, s.ExpiresOn)
|
logger.Printf("refreshed session: %s", s)
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
|
// redeemRefreshToken uses a RefreshToken with the RedeemURL to refresh the
|
||||||
|
// Access Token and (probably) the ID Token.
|
||||||
|
func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
|
||||||
clientSecret, err := p.GetClientSecret()
|
clientSecret, err := p.GetClientSecret()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c := oauth2.Config{
|
c := oauth2.Config{
|
||||||
@ -109,19 +161,14 @@ func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sessi
|
|||||||
return fmt.Errorf("failed to get token: %v", err)
|
return fmt.Errorf("failed to get token: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// in the token refresh response the id_token is optional
|
newSession, err := p.createSession(ctx, token, true)
|
||||||
idToken, err := p.findVerifiedIDToken(ctx, token)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to extract id_token from response: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
newSession, err := p.createSessionState(ctx, token, idToken)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable create new session state from response: %v", err)
|
return fmt.Errorf("unable create new session state from response: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// It's possible that if the refresh token isn't in the token response the session will not contain an id token
|
// It's possible that if the refresh token isn't in the token response the
|
||||||
// if it doesn't it's probably better to retain the old one
|
// session will not contain an id token.
|
||||||
|
// If it doesn't it's probably better to retain the old one
|
||||||
if newSession.IDToken != "" {
|
if newSession.IDToken != "" {
|
||||||
s.IDToken = newSession.IDToken
|
s.IDToken = newSession.IDToken
|
||||||
s.Email = newSession.Email
|
s.Email = newSession.Email
|
||||||
@ -135,102 +182,113 @@ func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sessi
|
|||||||
s.CreatedAt = newSession.CreatedAt
|
s.CreatedAt = newSession.CreatedAt
|
||||||
s.ExpiresOn = newSession.ExpiresOn
|
s.ExpiresOn = newSession.ExpiresOn
|
||||||
|
|
||||||
return
|
return nil
|
||||||
}
|
|
||||||
|
|
||||||
func (p *OIDCProvider) findVerifiedIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) {
|
|
||||||
|
|
||||||
getIDToken := func() (string, bool) {
|
|
||||||
rawIDToken, _ := token.Extra("id_token").(string)
|
|
||||||
return rawIDToken, len(strings.TrimSpace(rawIDToken)) > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
if rawIDToken, present := getIDToken(); present {
|
|
||||||
verifiedIDToken, err := p.Verifier.Verify(ctx, rawIDToken)
|
|
||||||
return verifiedIDToken, err
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token, idToken *oidc.IDToken) (*sessions.SessionState, error) {
|
|
||||||
|
|
||||||
var newSession *sessions.SessionState
|
|
||||||
|
|
||||||
if idToken == nil {
|
|
||||||
newSession = &sessions.SessionState{}
|
|
||||||
} else {
|
|
||||||
var err error
|
|
||||||
newSession, err = p.createSessionStateInternal(ctx, idToken, token)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
created := time.Now()
|
|
||||||
newSession.AccessToken = token.AccessToken
|
|
||||||
newSession.RefreshToken = token.RefreshToken
|
|
||||||
newSession.CreatedAt = &created
|
|
||||||
newSession.ExpiresOn = &token.Expiry
|
|
||||||
return newSession, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateSessionFromToken converts Bearer IDTokens into sessions
|
||||||
func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) {
|
func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) {
|
||||||
idToken, err := p.Verifier.Verify(ctx, token)
|
idToken, err := p.Verifier.Verify(ctx, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
newSession, err := p.createSessionStateInternal(ctx, idToken, nil)
|
ss, err := p.buildSessionFromClaims(idToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
newSession.AccessToken = token
|
// Allow empty Email in Bearer case since we can't hit the ProfileURL
|
||||||
newSession.IDToken = token
|
if ss.Email == "" {
|
||||||
newSession.RefreshToken = ""
|
ss.Email = ss.User
|
||||||
newSession.ExpiresOn = &idToken.Expiry
|
|
||||||
|
|
||||||
return newSession, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *OIDCProvider) createSessionStateInternal(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*sessions.SessionState, error) {
|
|
||||||
|
|
||||||
newSession := &sessions.SessionState{}
|
|
||||||
|
|
||||||
if idToken == nil {
|
|
||||||
return newSession, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, err := p.findClaimsFromIDToken(ctx, idToken, token)
|
ss.AccessToken = token
|
||||||
|
ss.IDToken = token
|
||||||
|
ss.RefreshToken = ""
|
||||||
|
ss.ExpiresOn = &idToken.Expiry
|
||||||
|
|
||||||
|
return ss, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createSession takes an oauth2.Token and creates a SessionState from it.
|
||||||
|
// It alters behavior if called from Redeem vs Refresh
|
||||||
|
func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, refresh bool) (*sessions.SessionState, error) {
|
||||||
|
idToken, err := p.findVerifiedIDToken(ctx, token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("could not verify id_token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDToken is mandatory in Redeem but optional in Refresh
|
||||||
|
if idToken == nil && !refresh {
|
||||||
|
return nil, errors.New("token response did not contain an id_token")
|
||||||
|
}
|
||||||
|
|
||||||
|
ss, err := p.buildSessionFromClaims(idToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ss.AccessToken = token.AccessToken
|
||||||
|
ss.RefreshToken = token.RefreshToken
|
||||||
|
ss.IDToken = getIDToken(token)
|
||||||
|
|
||||||
|
created := time.Now()
|
||||||
|
ss.CreatedAt = &created
|
||||||
|
ss.ExpiresOn = &token.Expiry
|
||||||
|
|
||||||
|
return ss, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OIDCProvider) findVerifiedIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) {
|
||||||
|
rawIDToken := getIDToken(token)
|
||||||
|
if strings.TrimSpace(rawIDToken) != "" {
|
||||||
|
return p.Verifier.Verify(ctx, rawIDToken)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildSessionFromClaims uses IDToken claims to populate a fresh SessionState
|
||||||
|
// with non-Token related fields.
|
||||||
|
func (p *OIDCProvider) buildSessionFromClaims(idToken *oidc.IDToken) (*sessions.SessionState, error) {
|
||||||
|
ss := &sessions.SessionState{}
|
||||||
|
|
||||||
|
if idToken == nil {
|
||||||
|
return ss, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, err := p.getClaims(idToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err)
|
return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if token != nil {
|
ss.User = claims.Subject
|
||||||
newSession.IDToken = token.Extra("id_token").(string)
|
ss.Email = claims.Email
|
||||||
|
ss.Groups = claims.Groups
|
||||||
|
|
||||||
|
// TODO (@NickMeves) Deprecate for dynamic claim to session mapping
|
||||||
|
if pref, ok := claims.rawClaims["preferred_username"].(string); ok {
|
||||||
|
ss.PreferredUsername = pref
|
||||||
}
|
}
|
||||||
|
|
||||||
newSession.Email = claims.UserID // TODO Rename SessionState.Email to .UserID in the near future
|
verifyEmail := (p.EmailClaim == emailClaim) && !p.AllowUnverifiedEmail
|
||||||
|
|
||||||
newSession.User = claims.Subject
|
|
||||||
newSession.Groups = claims.Groups
|
|
||||||
newSession.PreferredUsername = claims.PreferredUsername
|
|
||||||
|
|
||||||
verifyEmail := (p.UserIDClaim == emailClaim) && !p.AllowUnverifiedEmail
|
|
||||||
if verifyEmail && claims.Verified != nil && !*claims.Verified {
|
if verifyEmail && claims.Verified != nil && !*claims.Verified {
|
||||||
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.UserID)
|
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
return newSession, nil
|
return ss, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateSessionState checks that the session's IDToken is still valid
|
type OIDCClaims struct {
|
||||||
func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
|
Subject string `json:"sub"`
|
||||||
_, err := p.Verifier.Verify(ctx, s.IDToken)
|
Email string `json:"-"`
|
||||||
return err == nil
|
Groups []string `json:"-"`
|
||||||
|
Verified *bool `json:"email_verified"`
|
||||||
|
|
||||||
|
rawClaims map[string]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*OIDCClaims, error) {
|
// getClaims extracts IDToken claims into an OIDCClaims
|
||||||
|
func (p *OIDCProvider) getClaims(idToken *oidc.IDToken) (*OIDCClaims, error) {
|
||||||
claims := &OIDCClaims{}
|
claims := &OIDCClaims{}
|
||||||
|
|
||||||
// Extract default claims.
|
// Extract default claims.
|
||||||
@ -242,86 +300,28 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.
|
|||||||
return nil, fmt.Errorf("failed to parse all id_token claims: %v", err)
|
return nil, fmt.Errorf("failed to parse all id_token claims: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
userID := claims.rawClaims[p.UserIDClaim]
|
email := claims.rawClaims[p.EmailClaim]
|
||||||
if userID != nil {
|
if email != nil {
|
||||||
claims.UserID = fmt.Sprint(userID)
|
claims.Email = fmt.Sprint(email)
|
||||||
}
|
|
||||||
|
|
||||||
claims.Groups = p.extractGroupsFromRawClaims(claims.rawClaims)
|
|
||||||
|
|
||||||
// userID claim was not present or was empty in the ID Token
|
|
||||||
if claims.UserID == "" {
|
|
||||||
// BearerToken case, allow empty UserID
|
|
||||||
// ProfileURL checks below won't work since we don't have an access token
|
|
||||||
if token == nil {
|
|
||||||
claims.UserID = claims.Subject
|
|
||||||
return claims, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
profileURL := p.ProfileURL.String()
|
|
||||||
if profileURL == "" || token.AccessToken == "" {
|
|
||||||
return nil, fmt.Errorf("id_token did not contain user ID claim (%q)", p.UserIDClaim)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the userinfo endpoint profileURL is defined, then there is a chance the userinfo
|
|
||||||
// contents at the profileURL contains the email.
|
|
||||||
// Make a query to the userinfo endpoint, and attempt to locate the email from there.
|
|
||||||
respJSON, err := requests.New(profileURL).
|
|
||||||
WithContext(ctx).
|
|
||||||
WithHeaders(makeOIDCHeader(token.AccessToken)).
|
|
||||||
Do().
|
|
||||||
UnmarshalJSON()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
userID, err := respJSON.Get(p.UserIDClaim).String()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("neither id_token nor userinfo endpoint contained user ID claim (%q)", p.UserIDClaim)
|
|
||||||
}
|
|
||||||
|
|
||||||
claims.UserID = userID
|
|
||||||
}
|
}
|
||||||
|
claims.Groups = p.extractGroups(claims.rawClaims)
|
||||||
|
|
||||||
return claims, nil
|
return claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OIDCProvider) extractGroupsFromRawClaims(rawClaims map[string]interface{}) []string {
|
func (p *OIDCProvider) extractGroups(claims map[string]interface{}) []string {
|
||||||
groups := []string{}
|
groups := []string{}
|
||||||
|
rawGroups, ok := claims[p.GroupsClaim].([]interface{})
|
||||||
rawGroups, ok := rawClaims[p.GroupsClaim].([]interface{})
|
|
||||||
if rawGroups != nil && ok {
|
if rawGroups != nil && ok {
|
||||||
for _, rawGroup := range rawGroups {
|
for _, rawGroup := range rawGroups {
|
||||||
formattedGroup, err := formatGroup(rawGroup)
|
formattedGroup, err := formatGroup(rawGroup)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("unable to format group of type %s with error %s", reflect.TypeOf(rawGroup), err)
|
logger.Errorf("Warning: unable to format group of type %s with error %s",
|
||||||
|
reflect.TypeOf(rawGroup), err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
groups = append(groups, formattedGroup)
|
groups = append(groups, formattedGroup)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return groups
|
return groups
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func formatGroup(rawGroup interface{}) (string, error) {
|
|
||||||
group, ok := rawGroup.(string)
|
|
||||||
if !ok {
|
|
||||||
jsonGroup, err := json.Marshal(rawGroup)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
group = string(jsonGroup)
|
|
||||||
}
|
|
||||||
return group, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type OIDCClaims struct {
|
|
||||||
rawClaims map[string]interface{}
|
|
||||||
UserID string
|
|
||||||
Subject string `json:"sub"`
|
|
||||||
Verified *bool `json:"email_verified"`
|
|
||||||
PreferredUsername string `json:"preferred_username"`
|
|
||||||
Groups []string `json:"-"`
|
|
||||||
}
|
}
|
||||||
|
@ -154,7 +154,7 @@ func newOIDCProvider(serverURL *url.URL) *OIDCProvider {
|
|||||||
|
|
||||||
p := &OIDCProvider{
|
p := &OIDCProvider{
|
||||||
ProviderData: providerData,
|
ProviderData: providerData,
|
||||||
UserIDClaim: "email",
|
EmailClaim: "email",
|
||||||
}
|
}
|
||||||
|
|
||||||
return p
|
return p
|
||||||
@ -225,7 +225,7 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
server, provider := newTestSetup(body)
|
server, provider := newTestSetup(body)
|
||||||
provider.UserIDClaim = "phone_number"
|
provider.EmailClaim = "phone_number"
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234")
|
session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234")
|
||||||
@ -233,6 +233,256 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) {
|
|||||||
assert.Equal(t, defaultIDToken.Phone, session.Email)
|
assert.Equal(t, defaultIDToken.Phone, session.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOIDCProvider_EnrichSession(t *testing.T) {
|
||||||
|
const (
|
||||||
|
idToken = "Unchanged ID Token"
|
||||||
|
accessToken = "Unchanged Access Token"
|
||||||
|
refreshToken = "Unchanged Refresh Token"
|
||||||
|
)
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
ExistingSession *sessions.SessionState
|
||||||
|
EmailClaim string
|
||||||
|
GroupsClaim string
|
||||||
|
ProfileJSON map[string]interface{}
|
||||||
|
ExpectedError error
|
||||||
|
ExpectedSession *sessions.SessionState
|
||||||
|
}{
|
||||||
|
"Already Populated": {
|
||||||
|
ExistingSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
Groups: []string{"already", "populated"},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ProfileJSON: map[string]interface{}{
|
||||||
|
"email": "new@thing.com",
|
||||||
|
"groups": []string{"new", "thing"},
|
||||||
|
},
|
||||||
|
ExpectedError: nil,
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
Groups: []string{"already", "populated"},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Missing Email": {
|
||||||
|
ExistingSession: &sessions.SessionState{
|
||||||
|
User: "missing.email",
|
||||||
|
Groups: []string{"already", "populated"},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ProfileJSON: map[string]interface{}{
|
||||||
|
"email": "found@email.com",
|
||||||
|
"groups": []string{"new", "thing"},
|
||||||
|
},
|
||||||
|
ExpectedError: nil,
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "missing.email",
|
||||||
|
Email: "found@email.com",
|
||||||
|
Groups: []string{"already", "populated"},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
"Missing Email Only in Profile URL": {
|
||||||
|
ExistingSession: &sessions.SessionState{
|
||||||
|
User: "missing.email",
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ProfileJSON: map[string]interface{}{
|
||||||
|
"email": "found@email.com",
|
||||||
|
},
|
||||||
|
ExpectedError: nil,
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "missing.email",
|
||||||
|
Email: "found@email.com",
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Missing Email with Custom Claim": {
|
||||||
|
ExistingSession: &sessions.SessionState{
|
||||||
|
User: "missing.email",
|
||||||
|
Groups: []string{"already", "populated"},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
EmailClaim: "weird",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ProfileJSON: map[string]interface{}{
|
||||||
|
"weird": "weird@claim.com",
|
||||||
|
"groups": []string{"new", "thing"},
|
||||||
|
},
|
||||||
|
ExpectedError: nil,
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "missing.email",
|
||||||
|
Email: "weird@claim.com",
|
||||||
|
Groups: []string{"already", "populated"},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Missing Email not in Profile URL": {
|
||||||
|
ExistingSession: &sessions.SessionState{
|
||||||
|
User: "missing.email",
|
||||||
|
Groups: []string{"already", "populated"},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ProfileJSON: map[string]interface{}{
|
||||||
|
"groups": []string{"new", "thing"},
|
||||||
|
},
|
||||||
|
ExpectedError: errors.New("neither the id_token nor the profileURL set an email"),
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "missing.email",
|
||||||
|
Groups: []string{"already", "populated"},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Missing Groups": {
|
||||||
|
ExistingSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
Groups: []string{},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ProfileJSON: map[string]interface{}{
|
||||||
|
"email": "new@thing.com",
|
||||||
|
"groups": []string{"new", "thing"},
|
||||||
|
},
|
||||||
|
ExpectedError: nil,
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
Groups: []string{"new", "thing"},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Missing Groups with Custom Claim": {
|
||||||
|
ExistingSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
Groups: nil,
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "roles",
|
||||||
|
ProfileJSON: map[string]interface{}{
|
||||||
|
"email": "new@thing.com",
|
||||||
|
"roles": []string{"new", "thing", "roles"},
|
||||||
|
},
|
||||||
|
ExpectedError: nil,
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
Groups: []string{"new", "thing", "roles"},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Missing Groups String Profile URL Response": {
|
||||||
|
ExistingSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
Groups: []string{},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ProfileJSON: map[string]interface{}{
|
||||||
|
"email": "new@thing.com",
|
||||||
|
"groups": "singleton",
|
||||||
|
},
|
||||||
|
ExpectedError: nil,
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
Groups: []string{"singleton"},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Missing Groups in both Claims and Profile URL": {
|
||||||
|
ExistingSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ProfileJSON: map[string]interface{}{
|
||||||
|
"email": "new@thing.com",
|
||||||
|
},
|
||||||
|
ExpectedError: nil,
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for testName, tc := range testCases {
|
||||||
|
t.Run(testName, func(t *testing.T) {
|
||||||
|
jsonResp, err := json.Marshal(tc.ProfileJSON)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
server, provider := newTestSetup(jsonResp)
|
||||||
|
provider.ProfileURL, err = url.Parse(server.URL)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
provider.EmailClaim = tc.EmailClaim
|
||||||
|
provider.GroupsClaim = tc.GroupsClaim
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
err = provider.EnrichSession(context.Background(), tc.ExistingSession)
|
||||||
|
assert.Equal(t, tc.ExpectedError, err)
|
||||||
|
assert.Equal(t, *tc.ExpectedSession, *tc.ExistingSession)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
|
func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
|
||||||
|
|
||||||
idToken, _ := newSignedTestIDToken(defaultIDToken)
|
idToken, _ := newSignedTestIDToken(defaultIDToken)
|
||||||
@ -361,7 +611,7 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOIDCProvider_findVerifiedIdToken(t *testing.T) {
|
func TestOIDCProvider_findVerifiedIDToken(t *testing.T) {
|
||||||
|
|
||||||
server, provider := newTestSetup([]byte(""))
|
server, provider := newTestSetup([]byte(""))
|
||||||
|
|
||||||
@ -397,31 +647,3 @@ func TestOIDCProvider_findVerifiedIdToken(t *testing.T) {
|
|||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, true, verifiedIDToken == nil)
|
assert.Equal(t, true, verifiedIDToken == nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_formatGroup(t *testing.T) {
|
|
||||||
testCases := map[string]struct {
|
|
||||||
RawGroup interface{}
|
|
||||||
ExpectedFormattedGroupValue string
|
|
||||||
}{
|
|
||||||
"String Group": {
|
|
||||||
RawGroup: "group",
|
|
||||||
ExpectedFormattedGroupValue: "group",
|
|
||||||
},
|
|
||||||
"Map Group": {
|
|
||||||
RawGroup: map[string]string{"id": "1", "name": "Test"},
|
|
||||||
ExpectedFormattedGroupValue: "{\"id\":\"1\",\"name\":\"Test\"}",
|
|
||||||
},
|
|
||||||
"List Group": {
|
|
||||||
RawGroup: []string{"First", "Second"},
|
|
||||||
ExpectedFormattedGroupValue: "[\"First\",\"Second\"]",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for testName, tc := range testCases {
|
|
||||||
t.Run(testName, func(t *testing.T) {
|
|
||||||
formattedGroup, err := formatGroup(tc.RawGroup)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, tc.ExpectedFormattedGroupValue, formattedGroup)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
package providers
|
package providers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -55,3 +58,23 @@ func makeLoginURL(p *ProviderData, redirectURI, state string, extraParams url.Va
|
|||||||
a.RawQuery = params.Encode()
|
a.RawQuery = params.Encode()
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getIDToken(token *oauth2.Token) string {
|
||||||
|
idToken, ok := token.Extra("id_token").(string)
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return idToken
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatGroup(rawGroup interface{}) (string, error) {
|
||||||
|
group, ok := rawGroup.(string)
|
||||||
|
if !ok {
|
||||||
|
jsonGroup, err := json.Marshal(rawGroup)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
group = string(jsonGroup)
|
||||||
|
}
|
||||||
|
return group, nil
|
||||||
|
}
|
||||||
|
@ -5,9 +5,10 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMakeAuhtorizationHeader(t *testing.T) {
|
func Test_makeAuthorizationHeader(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
prefix string
|
prefix string
|
||||||
@ -64,3 +65,49 @@ func TestMakeAuhtorizationHeader(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_getIDToken(t *testing.T) {
|
||||||
|
const idToken = "eyJfoobar.eyJfoobar.12345asdf"
|
||||||
|
g := NewWithT(t)
|
||||||
|
|
||||||
|
token := &oauth2.Token{}
|
||||||
|
g.Expect(getIDToken(token)).To(Equal(""))
|
||||||
|
|
||||||
|
extraToken := token.WithExtra(map[string]interface{}{
|
||||||
|
"id_token": idToken,
|
||||||
|
})
|
||||||
|
g.Expect(getIDToken(extraToken)).To(Equal(idToken))
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_formatGroup(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
rawGroup interface{}
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
"String Group": {
|
||||||
|
rawGroup: "group",
|
||||||
|
expected: "group",
|
||||||
|
},
|
||||||
|
"Numeric Group": {
|
||||||
|
rawGroup: 123,
|
||||||
|
expected: "123",
|
||||||
|
},
|
||||||
|
"Map Group": {
|
||||||
|
rawGroup: map[string]string{"id": "1", "name": "Test"},
|
||||||
|
expected: "{\"id\":\"1\",\"name\":\"Test\"}",
|
||||||
|
},
|
||||||
|
"List Group": {
|
||||||
|
rawGroup: []string{"First", "Second"},
|
||||||
|
expected: "[\"First\",\"Second\"]",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for testName, tc := range testCases {
|
||||||
|
t.Run(testName, func(t *testing.T) {
|
||||||
|
g := NewWithT(t)
|
||||||
|
formattedGroup, err := formatGroup(tc.rawGroup)
|
||||||
|
g.Expect(err).ToNot(HaveOccurred())
|
||||||
|
g.Expect(formattedGroup).To(Equal(tc.expected))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user