1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-01-24 05:26:55 +02:00

Use global OIDC fields for Gitlab

This commit is contained in:
Nick Meves 2020-12-01 17:50:27 -08:00
parent 42f6cef7d6
commit d2ffef2c7e
No known key found for this signature in database
GPG Key ID: 93BA8A3CEDCDD1CF
5 changed files with 40 additions and 43 deletions

View File

@ -287,7 +287,6 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
msgs = append(msgs, "oidc provider requires an oidc issuer URL") msgs = append(msgs, "oidc provider requires an oidc issuer URL")
} }
case *providers.GitLabProvider: case *providers.GitLabProvider:
p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail
p.Groups = o.GitLabGroup p.Groups = o.GitLabGroup
err := p.AddProjects(o.GitlabProjects) err := p.AddProjects(o.GitlabProjects)
if err != nil { if err != nil {

View File

@ -20,8 +20,6 @@ type GitLabProvider struct {
Groups []string Groups []string
Projects []*GitlabProject Projects []*GitlabProject
AllowUnverifiedEmail bool
} }
// GitlabProject represents a Gitlab project constraint entity // GitlabProject represents a Gitlab project constraint entity
@ -103,7 +101,7 @@ func (p *GitLabProvider) Redeem(ctx context.Context, redirectURL, code string) (
if err != nil { if err != nil {
return nil, fmt.Errorf("token exchange: %v", err) return nil, fmt.Errorf("token exchange: %v", err)
} }
s, err = p.createSessionState(ctx, token) s, err = p.createSession(ctx, token)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to update session: %v", err) return nil, fmt.Errorf("unable to update session: %v", err)
} }
@ -162,7 +160,7 @@ func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.Ses
if err != nil { if err != nil {
return fmt.Errorf("failed to get token: %v", err) return fmt.Errorf("failed to get token: %v", err)
} }
newSession, err := p.createSessionState(ctx, token) newSession, err := p.createSession(ctx, token)
if err != nil { if err != nil {
return fmt.Errorf("unable to update session: %v", err) return fmt.Errorf("unable to update session: %v", err)
} }
@ -255,22 +253,21 @@ func (p *GitLabProvider) AddProjects(projects []string) error {
return nil return nil
} }
func (p *GitLabProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) { func (p *GitLabProvider) createSession(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) {
rawIDToken, ok := token.Extra("id_token").(string) idToken, err := p.verifyIDToken(ctx, token)
if !ok {
return nil, fmt.Errorf("token response did not contain an id_token")
}
// Parse and verify ID Token payload.
idToken, err := p.Verifier.Verify(ctx, rawIDToken)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not verify id_token: %v", err) switch err {
case ErrMissingIDToken:
return nil, fmt.Errorf("token response did not contain an id_token")
default:
return nil, fmt.Errorf("could not verify id_token: %v", err)
}
} }
created := time.Now() created := time.Now()
return &sessions.SessionState{ return &sessions.SessionState{
AccessToken: token.AccessToken, AccessToken: token.AccessToken,
IDToken: rawIDToken, IDToken: getIDToken(token),
RefreshToken: token.RefreshToken, RefreshToken: token.RefreshToken,
CreatedAt: &created, CreatedAt: &created,
ExpiresOn: &idToken.Expiry, ExpiresOn: &idToken.Expiry,

View File

@ -49,7 +49,7 @@ func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (*s
return p.createSession(ctx, token, false) return p.createSession(ctx, token, false)
} }
// EnrichSessionState is called after Redeem to allow providers to enrich session fields // EnrichSession is called after Redeem to allow providers to enrich session fields
// such as User, Email, Groups with provider specific API calls. // such as User, Email, Groups with provider specific API calls.
func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error {
if p.ProfileURL.String() == "" { if p.ProfileURL.String() == "" {
@ -61,7 +61,7 @@ func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionSta
// Try to get missing emails or groups from a profileURL // Try to get missing emails or groups from a profileURL
if s.Email == "" || s.Groups == nil { if s.Email == "" || s.Groups == nil {
err := p.callProfileURL(ctx, s) err := p.enrichFromProfileURL(ctx, s)
if err != nil { if err != nil {
logger.Errorf("Warning: Profile URL request failed: %v", err) logger.Errorf("Warning: Profile URL request failed: %v", err)
} }
@ -74,9 +74,9 @@ func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionSta
return nil return nil
} }
// callProfileURL enriches a session's Email & Groups via the JSON response of // enrichFromProfileURL enriches a session's Email & Groups via the JSON response of
// an OIDC profile URL // an OIDC profile URL
func (p *OIDCProvider) callProfileURL(ctx context.Context, s *sessions.SessionState) error { func (p *OIDCProvider) enrichFromProfileURL(ctx context.Context, s *sessions.SessionState) error {
respJSON, err := requests.New(p.ProfileURL.String()). respJSON, err := requests.New(p.ProfileURL.String()).
WithContext(ctx). WithContext(ctx).
WithHeaders(makeOIDCHeader(s.AccessToken)). WithHeaders(makeOIDCHeader(s.AccessToken)).
@ -91,22 +91,23 @@ func (p *OIDCProvider) callProfileURL(ctx context.Context, s *sessions.SessionSt
s.Email = email s.Email = email
} }
if len(s.Groups) == 0 { if len(s.Groups) > 0 {
for _, group := range coerceArray(respJSON, p.GroupsClaim) { return nil
formatted, err := formatGroup(group) }
if err != nil { for _, group := range coerceArray(respJSON, p.GroupsClaim) {
logger.Errorf("Warning: unable to format group of type %s with error %s", formatted, err := formatGroup(group)
reflect.TypeOf(group), err) if err != nil {
continue logger.Errorf("Warning: unable to format group of type %s with error %s",
} reflect.TypeOf(group), err)
s.Groups = append(s.Groups, formatted) continue
} }
s.Groups = append(s.Groups, formatted)
} }
return nil return nil
} }
// ValidateSessionState checks that the session's IDToken is still valid // ValidateSession checks that the session's IDToken is still valid
func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
_, err := p.Verifier.Verify(ctx, s.IDToken) _, err := p.Verifier.Verify(ctx, s.IDToken)
return err == nil return err == nil

View File

@ -128,13 +128,13 @@ type OIDCClaims struct {
func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) {
rawIDToken := getIDToken(token) rawIDToken := getIDToken(token)
if strings.TrimSpace(rawIDToken) != "" { if strings.TrimSpace(rawIDToken) == "" {
if p.Verifier == nil { return nil, ErrMissingIDToken
return nil, ErrMissingOIDCVerifier
}
return p.Verifier.Verify(ctx, rawIDToken)
} }
return nil, ErrMissingIDToken if p.Verifier == nil {
return nil, ErrMissingOIDCVerifier
}
return p.Verifier.Verify(ctx, rawIDToken)
} }
// buildSessionFromClaims uses IDToken claims to populate a fresh SessionState // buildSessionFromClaims uses IDToken claims to populate a fresh SessionState

View File

@ -73,15 +73,15 @@ func getIDToken(token *oauth2.Token) string {
// formatGroup coerces an OIDC groups claim into a string // formatGroup coerces an OIDC groups claim into a string
// If it is non-string, marshal it into JSON. // If it is non-string, marshal it into JSON.
func formatGroup(rawGroup interface{}) (string, error) { func formatGroup(rawGroup interface{}) (string, error) {
group, ok := rawGroup.(string) if group, ok := rawGroup.(string); ok {
if !ok { return group, nil
jsonGroup, err := json.Marshal(rawGroup)
if err != nil {
return "", err
}
group = string(jsonGroup)
} }
return group, nil
jsonGroup, err := json.Marshal(rawGroup)
if err != nil {
return "", err
}
return string(jsonGroup), nil
} }
// coerceArray extracts a field from simplejson.Json that might be a // coerceArray extracts a field from simplejson.Json that might be a