1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-03-25 22:00:56 +02:00

Refactor GitLab to EnrichSessionState

This commit is contained in:
Nick Meves 2020-09-27 15:15:33 -07:00
parent e51f5fe7c9
commit 0da45e97e1
No known key found for this signature in database
GPG Key ID: 93BA8A3CEDCDD1CF
2 changed files with 23 additions and 79 deletions

View File

@ -3,7 +3,6 @@ package providers
import (
"context"
"fmt"
"strings"
"time"
oidc "github.com/coreos/go-oidc"
@ -168,20 +167,6 @@ func (p *GitLabProvider) verifyGroupMembership(userInfo *gitlabUserInfo) error {
return fmt.Errorf("user is not a member of '%s'", p.Groups)
}
func (p *GitLabProvider) verifyEmailDomain(userInfo *gitlabUserInfo) error {
if len(p.EmailDomains) == 0 || p.EmailDomains[0] == "*" {
return nil
}
for _, domain := range p.EmailDomains {
if strings.HasSuffix(userInfo.Email, domain) {
return nil
}
}
return fmt.Errorf("user email is not one of the valid domains '%v'", p.EmailDomains)
}
func (p *GitLabProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) {
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
@ -211,39 +196,27 @@ func (p *GitLabProvider) ValidateSessionState(ctx context.Context, s *sessions.S
}
// GetEmailAddress returns the Account email address
func (p *GitLabProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
func (p *GitLabProvider) EnrichSessionState(ctx context.Context, s *sessions.SessionState) error {
// Retrieve user info
userInfo, err := p.getUserInfo(ctx, s)
if err != nil {
return "", fmt.Errorf("failed to retrieve user info: %v", err)
return fmt.Errorf("failed to retrieve user info: %v", err)
}
// Check if email is verified
if !p.AllowUnverifiedEmail && !userInfo.EmailVerified {
return "", fmt.Errorf("user email is not verified")
}
// Check if email has valid domain
err = p.verifyEmailDomain(userInfo)
if err != nil {
return "", fmt.Errorf("email domain check failed: %v", err)
return fmt.Errorf("user email is not verified")
}
// Check group membership
// TODO (@NickMeves) - Refactor to Authorize
err = p.verifyGroupMembership(userInfo)
if err != nil {
return "", fmt.Errorf("group membership check failed: %v", err)
return fmt.Errorf("group membership check failed: %v", err)
}
return userInfo.Email, nil
}
s.User = userInfo.Username
s.Email = userInfo.Email
// GetUserName returns the Account user name
func (p *GitLabProvider) GetUserName(ctx context.Context, s *sessions.SessionState) (string, error) {
userInfo, err := p.getUserInfo(ctx, s)
if err != nil {
return "", fmt.Errorf("failed to retrieve user info: %v", err)
}
return userInfo.Username, nil
return nil
}

View File

@ -64,8 +64,8 @@ func TestGitLabProviderBadToken(t *testing.T) {
p := testGitLabProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "unexpected_gitlab_access_token"}
_, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err)
err := p.EnrichSessionState(context.Background(), session)
assert.Error(t, err)
}
func TestGitLabProviderUnverifiedEmailDenied(t *testing.T) {
@ -76,8 +76,8 @@ func TestGitLabProviderUnverifiedEmailDenied(t *testing.T) {
p := testGitLabProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
_, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err)
err := p.EnrichSessionState(context.Background(), session)
assert.Error(t, err)
}
func TestGitLabProviderUnverifiedEmailAllowed(t *testing.T) {
@ -89,9 +89,9 @@ func TestGitLabProviderUnverifiedEmailAllowed(t *testing.T) {
p.AllowUnverifiedEmail = true
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err)
assert.Equal(t, "foo@bar.com", email)
err := p.EnrichSessionState(context.Background(), session)
assert.NoError(t, err)
assert.Equal(t, "foo@bar.com", session.Email)
}
func TestGitLabProviderUsername(t *testing.T) {
@ -103,9 +103,9 @@ func TestGitLabProviderUsername(t *testing.T) {
p.AllowUnverifiedEmail = true
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
username, err := p.GetUserName(context.Background(), session)
assert.Equal(t, nil, err)
assert.Equal(t, "FooBar", username)
err := p.EnrichSessionState(context.Background(), session)
assert.NoError(t, err)
assert.Equal(t, "FooBar", session.User)
}
func TestGitLabProviderGroupMembershipValid(t *testing.T) {
@ -118,9 +118,9 @@ func TestGitLabProviderGroupMembershipValid(t *testing.T) {
p.Groups = []string{"foo"}
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err)
assert.Equal(t, "foo@bar.com", email)
err := p.EnrichSessionState(context.Background(), session)
assert.NoError(t, err)
assert.Equal(t, "FooBar", session.User)
}
func TestGitLabProviderGroupMembershipMissing(t *testing.T) {
@ -133,35 +133,6 @@ func TestGitLabProviderGroupMembershipMissing(t *testing.T) {
p.Groups = []string{"baz"}
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
_, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err)
}
func TestGitLabProviderEmailDomainValid(t *testing.T) {
b := testGitLabBackend()
defer b.Close()
bURL, _ := url.Parse(b.URL)
p := testGitLabProvider(bURL.Host)
p.AllowUnverifiedEmail = true
p.EmailDomains = []string{"bar.com"}
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err)
assert.Equal(t, "foo@bar.com", email)
}
func TestGitLabProviderEmailDomainInvalid(t *testing.T) {
b := testGitLabBackend()
defer b.Close()
bURL, _ := url.Parse(b.URL)
p := testGitLabProvider(bURL.Host)
p.AllowUnverifiedEmail = true
p.EmailDomains = []string{"baz.com"}
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
_, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err)
err := p.EnrichSessionState(context.Background(), session)
assert.Error(t, err)
}