package header

import (
	"encoding/base64"
	"fmt"
	"net/http"

	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options/util"
	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
)

// Injector adds configured header values to an http.Header. The session is
// passed through to every value injector; claim-based injectors read their
// values from it.
type Injector interface {
	Inject(http.Header, *sessionsapi.SessionState)
}

// injector is the Injector implementation returned by NewInjector. It holds
// one valueInjector per configured header value.
type injector struct {
	valueInjectors []valueInjector
}

// Inject runs every value injector, in the order they were built, against
// header using session.
func (i injector) Inject(header http.Header, session *sessionsapi.SessionState) {
	for _, vi := range i.valueInjectors {
		vi.inject(header, session)
	}
}

// NewInjector builds an Injector from the given header configuration.
//
// One value injector is created for each value of each header, in the order
// they appear in headers. If any value cannot be turned into an injector, an
// error naming the header is returned; it wraps the underlying cause so
// callers can inspect it with errors.Is / errors.As.
func NewInjector(headers []options.Header) (Injector, error) {
	injectors := []valueInjector{}
	for _, header := range headers {
		for _, value := range header.Values {
			vi, err := newValueinjector(header.Name, value)
			if err != nil {
				return nil, fmt.Errorf("error building injector for header %q: %w", header.Name, err)
			}
			injectors = append(injectors, vi)
		}
	}
	return &injector{valueInjectors: injectors}, nil
}

// valueInjector adds the value(s) for a single configured header value.
type valueInjector interface {
	inject(http.Header, *sessionsapi.SessionState)
}

// newValueinjector creates the injector for a single header value.
//
// Exactly one of value.SecretSource and value.ClaimSource must be set. An
// error is returned when both are set and, with a distinct message, when
// neither is set.
func newValueinjector(name string, value options.HeaderValue) (valueInjector, error) {
	switch {
	case value.SecretSource != nil && value.ClaimSource != nil:
		return nil, fmt.Errorf("header %q value has multiple entries: only one entry per value is allowed", name)
	case value.SecretSource != nil:
		return newSecretInjector(name, value.SecretSource)
	case value.ClaimSource != nil:
		return newClaimInjector(name, value.ClaimSource)
	default:
		// Neither source is configured; reporting "multiple entries" here
		// would point the operator at the wrong problem.
		return nil, fmt.Errorf("header %q value has no secret or claim entry: exactly one entry per value is required", name)
	}
}

// injectorFunc adapts a plain function to the valueInjector interface.
type injectorFunc struct {
	injectFunc func(http.Header, *sessionsapi.SessionState)
}

// inject calls the wrapped function with header and session.
func (i *injectorFunc) inject(header http.Header, session *sessionsapi.SessionState) {
	i.injectFunc(header, session)
}

// newInjectorFunc wraps injectFunc in a valueInjector.
func newInjectorFunc(injectFunc func(header http.Header, session *sessionsapi.SessionState)) valueInjector {
	return &injectorFunc{injectFunc: injectFunc}
}

// newSecretInjector returns an injector that adds the secret described by
// source to the header called name. The secret is resolved once, here, via
// util.GetSecretValue; the session passed at inject time is not used.
func newSecretInjector(name string, source *options.SecretSource) (valueInjector, error) {
	value, err := util.GetSecretValue(source)
	if err != nil {
		return nil, fmt.Errorf("error getting secret value: %w", err)
	}

	// Convert once up front instead of on every inject call.
	secret := string(value)
	return newInjectorFunc(func(header http.Header, _ *sessionsapi.SessionState) {
		header.Add(name, secret)
	}), nil
}

// newClaimInjector returns an injector that adds one header value per
// non-empty value of the session claim source.Claim.
//
// The form of each added value depends on source:
//   - BasicAuthPassword set: "Basic " followed by the base64 encoding of
//     "<claim>:<password>". Prefix is not consulted in this case.
//   - Prefix non-empty: Prefix followed by the claim value.
//   - otherwise: the claim value unchanged.
//
// An error is returned only when the basic auth password cannot be loaded.
func newClaimInjector(name string, source *options.ClaimSource) (valueInjector, error) {
	switch {
	case source.BasicAuthPassword != nil:
		password, err := util.GetSecretValue(source.BasicAuthPassword)
		if err != nil {
			return nil, fmt.Errorf("error loading basicAuthPassword: %w", err)
		}

		// The ":<password>" part is the same for every claim value, so build
		// it once rather than inside the per-call loop.
		suffix := ":" + string(password)
		return newFormattedClaimInjector(name, source, func(claim string) string {
			return "Basic " + base64.StdEncoding.EncodeToString([]byte(claim+suffix))
		}), nil
	case source.Prefix != "":
		return newFormattedClaimInjector(name, source, func(claim string) string {
			return source.Prefix + claim
		}), nil
	default:
		return newFormattedClaimInjector(name, source, func(claim string) string {
			return claim
		}), nil
	}
}

// newFormattedClaimInjector returns an injector that looks up source.Claim in
// the session, skips empty values, and adds format(value) to the header called
// name for each remaining value.
func newFormattedClaimInjector(name string, source *options.ClaimSource, format func(claim string) string) valueInjector {
	return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) {
		for _, claim := range session.GetClaim(source.Claim) {
			if claim == "" {
				continue
			}
			header.Add(name, format(claim))
		}
	})
}