1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-01-06 03:53:54 +02:00

Make claims list of strings

This commit is contained in:
Joel Speed 2020-10-03 18:57:25 +01:00
parent c9b3422801
commit 70990327d1
No known key found for this signature in database
GPG Key ID: 6E80578D6751DEFB
2 changed files with 33 additions and 26 deletions

View File

@ -8,7 +8,6 @@ import (
"io"
"io/ioutil"
"reflect"
"strings"
"time"
"unicode/utf8"
@ -70,31 +69,33 @@ func (s *SessionState) String() string {
return o + "}"
}
func (s *SessionState) GetClaim(claim string) string {
func (s *SessionState) GetClaim(claim string) []string {
if s == nil {
return ""
return []string{}
}
switch claim {
case "access_token":
return s.AccessToken
return []string{s.AccessToken}
case "id_token":
return s.IDToken
return []string{s.IDToken}
case "created_at":
return s.CreatedAt.String()
return []string{s.CreatedAt.String()}
case "expires_on":
return s.ExpiresOn.String()
return []string{s.ExpiresOn.String()}
case "refresh_token":
return s.RefreshToken
return []string{s.RefreshToken}
case "email":
return s.Email
return []string{s.Email}
case "user":
return s.User
return []string{s.User}
case "groups":
return strings.Join(s.Groups, ",")
groups := make([]string, len(s.Groups))
copy(groups, s.Groups)
return groups
case "preferred_username":
return s.PreferredUsername
return []string{s.PreferredUsername}
default:
return ""
return []string{}
}
}

View File

@ -85,28 +85,34 @@ func newClaimInjector(name string, source *options.ClaimSource) (valueInjector,
return nil, fmt.Errorf("error loading basicAuthPassword: %v", err)
}
return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) {
claim := session.GetClaim(source.Claim)
if claim == "" {
return
claimValues := session.GetClaim(source.Claim)
for _, claim := range claimValues {
if claim == "" {
continue
}
auth := claim + ":" + string(password)
header.Add(name, "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
}
auth := claim + ":" + string(password)
header.Add(name, "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
}), nil
case source.Prefix != "":
return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) {
claim := session.GetClaim(source.Claim)
if claim == "" {
return
claimValues := session.GetClaim(source.Claim)
for _, claim := range claimValues {
if claim == "" {
continue
}
header.Add(name, source.Prefix+claim)
}
header.Add(name, source.Prefix+claim)
}), nil
default:
return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) {
claim := session.GetClaim(source.Claim)
if claim == "" {
return
claimValues := session.GetClaim(source.Claim)
for _, claim := range claimValues {
if claim == "" {
continue
}
header.Add(name, claim)
}
header.Add(name, claim)
}), nil
}
}