2020-07-26 04:50:18 +01:00
|
|
|
package header
|
|
|
|
|
|
|
|
import (
|
|
|
|
"encoding/base64"
|
|
|
|
"fmt"
|
|
|
|
"net/http"
|
|
|
|
|
|
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
|
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options/util"
|
|
|
|
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
|
|
|
)
|
|
|
|
|
|
|
|
type Injector interface {
|
|
|
|
Inject(http.Header, *sessionsapi.SessionState)
|
|
|
|
}
|
|
|
|
|
|
|
|
type injector struct {
|
|
|
|
valueInjectors []valueInjector
|
|
|
|
}
|
|
|
|
|
|
|
|
func (i injector) Inject(header http.Header, session *sessionsapi.SessionState) {
|
|
|
|
for _, injector := range i.valueInjectors {
|
|
|
|
injector.inject(header, session)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func NewInjector(headers []options.Header) (Injector, error) {
|
|
|
|
injectors := []valueInjector{}
|
|
|
|
for _, header := range headers {
|
|
|
|
for _, value := range header.Values {
|
|
|
|
injector, err := newValueinjector(header.Name, value)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("error building injector for header %q: %v", header.Name, err)
|
|
|
|
}
|
|
|
|
injectors = append(injectors, injector)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return &injector{valueInjectors: injectors}, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
type valueInjector interface {
|
|
|
|
inject(http.Header, *sessionsapi.SessionState)
|
|
|
|
}
|
|
|
|
|
|
|
|
func newValueinjector(name string, value options.HeaderValue) (valueInjector, error) {
|
|
|
|
switch {
|
|
|
|
case value.SecretSource != nil && value.ClaimSource == nil:
|
|
|
|
return newSecretInjector(name, value.SecretSource)
|
|
|
|
case value.SecretSource == nil && value.ClaimSource != nil:
|
|
|
|
return newClaimInjector(name, value.ClaimSource)
|
|
|
|
default:
|
|
|
|
return nil, fmt.Errorf("header %q value has multiple entries: only one entry per value is allowed", name)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
type injectorFunc struct {
|
|
|
|
injectFunc func(http.Header, *sessionsapi.SessionState)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (i *injectorFunc) inject(header http.Header, session *sessionsapi.SessionState) {
|
|
|
|
i.injectFunc(header, session)
|
|
|
|
}
|
|
|
|
|
|
|
|
func newInjectorFunc(injectFunc func(header http.Header, session *sessionsapi.SessionState)) valueInjector {
|
|
|
|
return &injectorFunc{injectFunc: injectFunc}
|
|
|
|
}
|
|
|
|
|
|
|
|
func newSecretInjector(name string, source *options.SecretSource) (valueInjector, error) {
|
|
|
|
value, err := util.GetSecretValue(source)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("error getting secret value: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) {
|
|
|
|
header.Add(name, string(value))
|
|
|
|
}), nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func newClaimInjector(name string, source *options.ClaimSource) (valueInjector, error) {
|
|
|
|
switch {
|
|
|
|
case source.BasicAuthPassword != nil:
|
|
|
|
password, err := util.GetSecretValue(source.BasicAuthPassword)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("error loading basicAuthPassword: %v", err)
|
|
|
|
}
|
|
|
|
return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) {
|
2020-10-03 18:57:25 +01:00
|
|
|
claimValues := session.GetClaim(source.Claim)
|
|
|
|
for _, claim := range claimValues {
|
|
|
|
if claim == "" {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
auth := claim + ":" + string(password)
|
|
|
|
header.Add(name, "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
|
2020-07-26 04:50:18 +01:00
|
|
|
}
|
|
|
|
}), nil
|
|
|
|
case source.Prefix != "":
|
|
|
|
return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) {
|
2020-10-03 18:57:25 +01:00
|
|
|
claimValues := session.GetClaim(source.Claim)
|
|
|
|
for _, claim := range claimValues {
|
|
|
|
if claim == "" {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
header.Add(name, source.Prefix+claim)
|
2020-07-26 04:50:18 +01:00
|
|
|
}
|
|
|
|
}), nil
|
|
|
|
default:
|
|
|
|
return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) {
|
2020-10-03 18:57:25 +01:00
|
|
|
claimValues := session.GetClaim(source.Claim)
|
|
|
|
for _, claim := range claimValues {
|
|
|
|
if claim == "" {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
header.Add(name, claim)
|
2020-07-26 04:50:18 +01:00
|
|
|
}
|
|
|
|
}), nil
|
|
|
|
}
|
|
|
|
}
|