From 95f9de5979e897fe7a61dc439245fade44f20ed9 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sat, 19 Jun 2021 15:18:50 -0700 Subject: [PATCH] Preserve projects after `RefreshSession` RefreshSession will override session.Groups with the new `groups` claims. We need to preserve all `project:` prefixed groups and reattach them post refresh. --- providers/gitlab.go | 58 ++++++++++++++++++++++++++++++++++------ providers/gitlab_test.go | 51 +++++++++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+), 8 deletions(-) diff --git a/providers/gitlab.go b/providers/gitlab.go index af0bcdae..a90669f6 100644 --- a/providers/gitlab.go +++ b/providers/gitlab.go @@ -13,9 +13,10 @@ import ( ) const ( - gitlabProviderName = "GitLab" - gitlabDefaultScope = "openid email" - gitlabUserClaim = "nickname" + gitlabProviderName = "GitLab" + gitlabDefaultScope = "openid email" + gitlabUserClaim = "nickname" + gitlabProjectPrefix = "project:" ) // GitLabProvider represents a GitLab based Identity Provider @@ -23,6 +24,8 @@ type GitLabProvider struct { *OIDCProvider allowedProjects []*gitlabProject + // Expose this for unit testing + oidcRefreshFunc func(context.Context, *sessions.SessionState) (bool, error) } var _ Provider = (*GitLabProvider)(nil) @@ -35,11 +38,14 @@ func NewGitLabProvider(p *ProviderData) *GitLabProvider { p.Scope = gitlabDefaultScope } + oidcProvider := &OIDCProvider{ + ProviderData: p, + SkipNonce: false, + } + return &GitLabProvider{ - OIDCProvider: &OIDCProvider{ - ProviderData: p, - SkipNonce: false, - }, + OIDCProvider: oidcProvider, + oidcRefreshFunc: oidcProvider.RefreshSession, } } @@ -245,5 +251,41 @@ func (p *GitLabProvider) getProjectInfo(ctx context.Context, s *sessions.Session } func formatProject(project *gitlabProject) string { - return fmt.Sprintf("project:%s", project.Name) + return gitlabProjectPrefix + project.Name +} + +// RefreshSession refreshes the session with the OIDCProvider implementation +// but preserves the custom GitLab projects added in the `EnrichSession` stage. +func (p *GitLabProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { + projects := getSessionProjects(s) + // This will overwrite s.Groups with the new IDToken's `groups` claims + refreshed, err := p.oidcRefreshFunc(ctx, s) + if refreshed && err == nil { + s.Groups = append(s.Groups, projects...) + s.Groups = deduplicateGroups(s.Groups) + } + return refreshed, err +} + +func getSessionProjects(s *sessions.SessionState) []string { + var projects []string + for _, group := range s.Groups { + if strings.HasPrefix(group, gitlabProjectPrefix) { + projects = append(projects, group) + } + } + return projects +} + +func deduplicateGroups(groups []string) []string { + groupSet := make(map[string]struct{}) + for _, group := range groups { + groupSet[group] = struct{}{} + } + + uniqueGroups := make([]string, 0, len(groupSet)) + for group := range groupSet { + uniqueGroups = append(uniqueGroups, group) + } + return uniqueGroups } diff --git a/providers/gitlab_test.go b/providers/gitlab_test.go index 699ebb0f..cad77ea2 100644 --- a/providers/gitlab_test.go +++ b/providers/gitlab_test.go @@ -307,4 +307,55 @@ var _ = Describe("Gitlab Provider Tests", func() { }), ) }) + + Context("when refreshing", func() { + It("keeps existing projects after refreshing groups", func() { + session := &sessions.SessionState{} + session.Groups = []string{"foo", "bar", "project:thing", "project:sample"} + + p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) { + s.Groups = []string{"baz"} + return true, nil + } + + refreshed, err := p.RefreshSession(context.Background(), session) + Expect(refreshed).To(BeTrue()) + Expect(err).ToNot(HaveOccurred()) + Expect(len(session.Groups)).To(Equal(3)) + Expect(session.Groups). + To(ContainElements([]string{"baz", "project:thing", "project:sample"})) + }) + + It("leaves existing groups when not refreshed", func() { + session := &sessions.SessionState{} + session.Groups = []string{"foo", "bar", "project:thing", "project:sample"} + + p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) { + return false, nil + } + + refreshed, err := p.RefreshSession(context.Background(), session) + Expect(refreshed).To(BeFalse()) + Expect(err).ToNot(HaveOccurred()) + Expect(len(session.Groups)).To(Equal(4)) + Expect(session.Groups). + To(ContainElements([]string{"foo", "bar", "project:thing", "project:sample"})) + }) + + It("leaves existing groups when OIDC refresh errors", func() { + session := &sessions.SessionState{} + session.Groups = []string{"foo", "bar", "project:thing", "project:sample"} + + p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) { + return false, errors.New("failure") + } + + refreshed, err := p.RefreshSession(context.Background(), session) + Expect(refreshed).To(BeFalse()) + Expect(err).To(HaveOccurred()) + Expect(len(session.Groups)).To(Equal(4)) + Expect(session.Groups). + To(ContainElements([]string{"foo", "bar", "project:thing", "project:sample"})) + }) + }) })