mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-04-23 12:18:50 +02:00
Preserve projects after RefreshSession
RefreshSession will override session.Groups with the new `groups` claims. We need to preserve all `project:` prefixed groups and reattach them post refresh.
This commit is contained in:
parent
11c2177f18
commit
95f9de5979
@ -16,6 +16,7 @@ const (
|
|||||||
gitlabProviderName = "GitLab"
|
gitlabProviderName = "GitLab"
|
||||||
gitlabDefaultScope = "openid email"
|
gitlabDefaultScope = "openid email"
|
||||||
gitlabUserClaim = "nickname"
|
gitlabUserClaim = "nickname"
|
||||||
|
gitlabProjectPrefix = "project:"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GitLabProvider represents a GitLab based Identity Provider
|
// GitLabProvider represents a GitLab based Identity Provider
|
||||||
@ -23,6 +24,8 @@ type GitLabProvider struct {
|
|||||||
*OIDCProvider
|
*OIDCProvider
|
||||||
|
|
||||||
allowedProjects []*gitlabProject
|
allowedProjects []*gitlabProject
|
||||||
|
// Expose this for unit testing
|
||||||
|
oidcRefreshFunc func(context.Context, *sessions.SessionState) (bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Provider = (*GitLabProvider)(nil)
|
var _ Provider = (*GitLabProvider)(nil)
|
||||||
@ -35,11 +38,14 @@ func NewGitLabProvider(p *ProviderData) *GitLabProvider {
|
|||||||
p.Scope = gitlabDefaultScope
|
p.Scope = gitlabDefaultScope
|
||||||
}
|
}
|
||||||
|
|
||||||
return &GitLabProvider{
|
oidcProvider := &OIDCProvider{
|
||||||
OIDCProvider: &OIDCProvider{
|
|
||||||
ProviderData: p,
|
ProviderData: p,
|
||||||
SkipNonce: false,
|
SkipNonce: false,
|
||||||
},
|
}
|
||||||
|
|
||||||
|
return &GitLabProvider{
|
||||||
|
OIDCProvider: oidcProvider,
|
||||||
|
oidcRefreshFunc: oidcProvider.RefreshSession,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -245,5 +251,41 @@ func (p *GitLabProvider) getProjectInfo(ctx context.Context, s *sessions.Session
|
|||||||
}
|
}
|
||||||
|
|
||||||
func formatProject(project *gitlabProject) string {
|
func formatProject(project *gitlabProject) string {
|
||||||
return fmt.Sprintf("project:%s", project.Name)
|
return gitlabProjectPrefix + project.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshSession refreshes the session with the OIDCProvider implementation
|
||||||
|
// but preserves the custom GitLab projects added in the `EnrichSession` stage.
|
||||||
|
func (p *GitLabProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||||
|
projects := getSessionProjects(s)
|
||||||
|
// This will overwrite s.Groups with the new IDToken's `groups` claims
|
||||||
|
refreshed, err := p.oidcRefreshFunc(ctx, s)
|
||||||
|
if refreshed && err == nil {
|
||||||
|
s.Groups = append(s.Groups, projects...)
|
||||||
|
s.Groups = deduplicateGroups(s.Groups)
|
||||||
|
}
|
||||||
|
return refreshed, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func getSessionProjects(s *sessions.SessionState) []string {
|
||||||
|
var projects []string
|
||||||
|
for _, group := range s.Groups {
|
||||||
|
if strings.HasPrefix(group, gitlabProjectPrefix) {
|
||||||
|
projects = append(projects, group)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return projects
|
||||||
|
}
|
||||||
|
|
||||||
|
func deduplicateGroups(groups []string) []string {
|
||||||
|
groupSet := make(map[string]struct{})
|
||||||
|
for _, group := range groups {
|
||||||
|
groupSet[group] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
uniqueGroups := make([]string, 0, len(groupSet))
|
||||||
|
for group := range groupSet {
|
||||||
|
uniqueGroups = append(uniqueGroups, group)
|
||||||
|
}
|
||||||
|
return uniqueGroups
|
||||||
}
|
}
|
||||||
|
@ -307,4 +307,55 @@ var _ = Describe("Gitlab Provider Tests", func() {
|
|||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Context("when refreshing", func() {
|
||||||
|
It("keeps existing projects after refreshing groups", func() {
|
||||||
|
session := &sessions.SessionState{}
|
||||||
|
session.Groups = []string{"foo", "bar", "project:thing", "project:sample"}
|
||||||
|
|
||||||
|
p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) {
|
||||||
|
s.Groups = []string{"baz"}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshed, err := p.RefreshSession(context.Background(), session)
|
||||||
|
Expect(refreshed).To(BeTrue())
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(len(session.Groups)).To(Equal(3))
|
||||||
|
Expect(session.Groups).
|
||||||
|
To(ContainElements([]string{"baz", "project:thing", "project:sample"}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("leaves existing groups when not refreshed", func() {
|
||||||
|
session := &sessions.SessionState{}
|
||||||
|
session.Groups = []string{"foo", "bar", "project:thing", "project:sample"}
|
||||||
|
|
||||||
|
p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshed, err := p.RefreshSession(context.Background(), session)
|
||||||
|
Expect(refreshed).To(BeFalse())
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(len(session.Groups)).To(Equal(4))
|
||||||
|
Expect(session.Groups).
|
||||||
|
To(ContainElements([]string{"foo", "bar", "project:thing", "project:sample"}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("leaves existing groups when OIDC refresh errors", func() {
|
||||||
|
session := &sessions.SessionState{}
|
||||||
|
session.Groups = []string{"foo", "bar", "project:thing", "project:sample"}
|
||||||
|
|
||||||
|
p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) {
|
||||||
|
return false, errors.New("failure")
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshed, err := p.RefreshSession(context.Background(), session)
|
||||||
|
Expect(refreshed).To(BeFalse())
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
Expect(len(session.Groups)).To(Equal(4))
|
||||||
|
Expect(session.Groups).
|
||||||
|
To(ContainElements([]string{"foo", "bar", "project:thing", "project:sample"}))
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
Loading…
x
Reference in New Issue
Block a user