From 70990327d187920c327cedfb847c0cecf70a0a4f Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Sat, 3 Oct 2020 18:57:25 +0100 Subject: [PATCH] Make claims list of strings --- pkg/apis/sessions/session_state.go | 27 +++++++++++++------------ pkg/header/injector.go | 32 ++++++++++++++++++------------ 2 files changed, 33 insertions(+), 26 deletions(-) diff --git a/pkg/apis/sessions/session_state.go b/pkg/apis/sessions/session_state.go index 3f675135..03bc747a 100644 --- a/pkg/apis/sessions/session_state.go +++ b/pkg/apis/sessions/session_state.go @@ -8,7 +8,6 @@ import ( "io" "io/ioutil" "reflect" - "strings" "time" "unicode/utf8" @@ -70,31 +69,33 @@ func (s *SessionState) String() string { return o + "}" } -func (s *SessionState) GetClaim(claim string) string { +func (s *SessionState) GetClaim(claim string) []string { if s == nil { - return "" + return []string{} } switch claim { case "access_token": - return s.AccessToken + return []string{s.AccessToken} case "id_token": - return s.IDToken + return []string{s.IDToken} case "created_at": - return s.CreatedAt.String() + return []string{s.CreatedAt.String()} case "expires_on": - return s.ExpiresOn.String() + return []string{s.ExpiresOn.String()} case "refresh_token": - return s.RefreshToken + return []string{s.RefreshToken} case "email": - return s.Email + return []string{s.Email} case "user": - return s.User + return []string{s.User} case "groups": - return strings.Join(s.Groups, ",") + groups := make([]string, len(s.Groups)) + copy(groups, s.Groups) + return groups case "preferred_username": - return s.PreferredUsername + return []string{s.PreferredUsername} default: - return "" + return []string{} } } diff --git a/pkg/header/injector.go b/pkg/header/injector.go index 136185e5..9c6e2fcd 100644 --- a/pkg/header/injector.go +++ b/pkg/header/injector.go @@ -85,28 +85,34 @@ func newClaimInjector(name string, source *options.ClaimSource) (valueInjector, return nil, fmt.Errorf("error loading basicAuthPassword: %v", err) } return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) { - claim := session.GetClaim(source.Claim) - if claim == "" { - return + claimValues := session.GetClaim(source.Claim) + for _, claim := range claimValues { + if claim == "" { + continue + } + auth := claim + ":" + string(password) + header.Add(name, "Basic "+base64.StdEncoding.EncodeToString([]byte(auth))) } - auth := claim + ":" + string(password) - header.Add(name, "Basic "+base64.StdEncoding.EncodeToString([]byte(auth))) }), nil case source.Prefix != "": return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) { - claim := session.GetClaim(source.Claim) - if claim == "" { - return + claimValues := session.GetClaim(source.Claim) + for _, claim := range claimValues { + if claim == "" { + continue + } + header.Add(name, source.Prefix+claim) } - header.Add(name, source.Prefix+claim) }), nil default: return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) { - claim := session.GetClaim(source.Claim) - if claim == "" { - return + claimValues := session.GetClaim(source.Claim) + for _, claim := range claimValues { + if claim == "" { + continue + } + header.Add(name, claim) } - header.Add(name, claim) }), nil } }