mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-01-06 03:53:54 +02:00
Make claims list of strings
This commit is contained in:
parent
c9b3422801
commit
70990327d1
@ -8,7 +8,6 @@ import (
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
@ -70,31 +69,33 @@ func (s *SessionState) String() string {
|
||||
return o + "}"
|
||||
}
|
||||
|
||||
func (s *SessionState) GetClaim(claim string) string {
|
||||
func (s *SessionState) GetClaim(claim string) []string {
|
||||
if s == nil {
|
||||
return ""
|
||||
return []string{}
|
||||
}
|
||||
switch claim {
|
||||
case "access_token":
|
||||
return s.AccessToken
|
||||
return []string{s.AccessToken}
|
||||
case "id_token":
|
||||
return s.IDToken
|
||||
return []string{s.IDToken}
|
||||
case "created_at":
|
||||
return s.CreatedAt.String()
|
||||
return []string{s.CreatedAt.String()}
|
||||
case "expires_on":
|
||||
return s.ExpiresOn.String()
|
||||
return []string{s.ExpiresOn.String()}
|
||||
case "refresh_token":
|
||||
return s.RefreshToken
|
||||
return []string{s.RefreshToken}
|
||||
case "email":
|
||||
return s.Email
|
||||
return []string{s.Email}
|
||||
case "user":
|
||||
return s.User
|
||||
return []string{s.User}
|
||||
case "groups":
|
||||
return strings.Join(s.Groups, ",")
|
||||
groups := make([]string, len(s.Groups))
|
||||
copy(groups, s.Groups)
|
||||
return groups
|
||||
case "preferred_username":
|
||||
return s.PreferredUsername
|
||||
return []string{s.PreferredUsername}
|
||||
default:
|
||||
return ""
|
||||
return []string{}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -85,28 +85,34 @@ func newClaimInjector(name string, source *options.ClaimSource) (valueInjector,
|
||||
return nil, fmt.Errorf("error loading basicAuthPassword: %v", err)
|
||||
}
|
||||
return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) {
|
||||
claim := session.GetClaim(source.Claim)
|
||||
if claim == "" {
|
||||
return
|
||||
claimValues := session.GetClaim(source.Claim)
|
||||
for _, claim := range claimValues {
|
||||
if claim == "" {
|
||||
continue
|
||||
}
|
||||
auth := claim + ":" + string(password)
|
||||
header.Add(name, "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
|
||||
}
|
||||
auth := claim + ":" + string(password)
|
||||
header.Add(name, "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
|
||||
}), nil
|
||||
case source.Prefix != "":
|
||||
return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) {
|
||||
claim := session.GetClaim(source.Claim)
|
||||
if claim == "" {
|
||||
return
|
||||
claimValues := session.GetClaim(source.Claim)
|
||||
for _, claim := range claimValues {
|
||||
if claim == "" {
|
||||
continue
|
||||
}
|
||||
header.Add(name, source.Prefix+claim)
|
||||
}
|
||||
header.Add(name, source.Prefix+claim)
|
||||
}), nil
|
||||
default:
|
||||
return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) {
|
||||
claim := session.GetClaim(source.Claim)
|
||||
if claim == "" {
|
||||
return
|
||||
claimValues := session.GetClaim(source.Claim)
|
||||
for _, claim := range claimValues {
|
||||
if claim == "" {
|
||||
continue
|
||||
}
|
||||
header.Add(name, claim)
|
||||
}
|
||||
header.Add(name, claim)
|
||||
}), nil
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user