1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-01-24 05:26:55 +02:00

Make claims list of strings

This commit is contained in:
Joel Speed 2020-10-03 18:57:25 +01:00
parent c9b3422801
commit 70990327d1
No known key found for this signature in database
GPG Key ID: 6E80578D6751DEFB
2 changed files with 33 additions and 26 deletions

View File

@ -8,7 +8,6 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"reflect" "reflect"
"strings"
"time" "time"
"unicode/utf8" "unicode/utf8"
@ -70,31 +69,33 @@ func (s *SessionState) String() string {
return o + "}" return o + "}"
} }
func (s *SessionState) GetClaim(claim string) string { func (s *SessionState) GetClaim(claim string) []string {
if s == nil { if s == nil {
return "" return []string{}
} }
switch claim { switch claim {
case "access_token": case "access_token":
return s.AccessToken return []string{s.AccessToken}
case "id_token": case "id_token":
return s.IDToken return []string{s.IDToken}
case "created_at": case "created_at":
return s.CreatedAt.String() return []string{s.CreatedAt.String()}
case "expires_on": case "expires_on":
return s.ExpiresOn.String() return []string{s.ExpiresOn.String()}
case "refresh_token": case "refresh_token":
return s.RefreshToken return []string{s.RefreshToken}
case "email": case "email":
return s.Email return []string{s.Email}
case "user": case "user":
return s.User return []string{s.User}
case "groups": case "groups":
return strings.Join(s.Groups, ",") groups := make([]string, len(s.Groups))
copy(groups, s.Groups)
return groups
case "preferred_username": case "preferred_username":
return s.PreferredUsername return []string{s.PreferredUsername}
default: default:
return "" return []string{}
} }
} }

View File

@ -85,28 +85,34 @@ func newClaimInjector(name string, source *options.ClaimSource) (valueInjector,
return nil, fmt.Errorf("error loading basicAuthPassword: %v", err) return nil, fmt.Errorf("error loading basicAuthPassword: %v", err)
} }
return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) { return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) {
claim := session.GetClaim(source.Claim) claimValues := session.GetClaim(source.Claim)
for _, claim := range claimValues {
if claim == "" { if claim == "" {
return continue
} }
auth := claim + ":" + string(password) auth := claim + ":" + string(password)
header.Add(name, "Basic "+base64.StdEncoding.EncodeToString([]byte(auth))) header.Add(name, "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
}
}), nil }), nil
case source.Prefix != "": case source.Prefix != "":
return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) { return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) {
claim := session.GetClaim(source.Claim) claimValues := session.GetClaim(source.Claim)
for _, claim := range claimValues {
if claim == "" { if claim == "" {
return continue
} }
header.Add(name, source.Prefix+claim) header.Add(name, source.Prefix+claim)
}
}), nil }), nil
default: default:
return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) { return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) {
claim := session.GetClaim(source.Claim) claimValues := session.GetClaim(source.Claim)
for _, claim := range claimValues {
if claim == "" { if claim == "" {
return continue
} }
header.Add(name, claim) header.Add(name, claim)
}
}), nil }), nil
} }
} }