1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-01-10 04:18:14 +02:00
oauth2-proxy/pkg/header/injector.go
2020-10-07 18:25:00 +01:00

119 lines
3.4 KiB
Go

package header
import (
"encoding/base64"
"fmt"
"net/http"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options/util"
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
)
type Injector interface {
Inject(http.Header, *sessionsapi.SessionState)
}
type injector struct {
valueInjectors []valueInjector
}
func (i injector) Inject(header http.Header, session *sessionsapi.SessionState) {
for _, injector := range i.valueInjectors {
injector.inject(header, session)
}
}
func NewInjector(headers []options.Header) (Injector, error) {
injectors := []valueInjector{}
for _, header := range headers {
for _, value := range header.Values {
injector, err := newValueinjector(header.Name, value)
if err != nil {
return nil, fmt.Errorf("error building injector for header %q: %v", header.Name, err)
}
injectors = append(injectors, injector)
}
}
return &injector{valueInjectors: injectors}, nil
}
type valueInjector interface {
inject(http.Header, *sessionsapi.SessionState)
}
func newValueinjector(name string, value options.HeaderValue) (valueInjector, error) {
switch {
case value.SecretSource != nil && value.ClaimSource == nil:
return newSecretInjector(name, value.SecretSource)
case value.SecretSource == nil && value.ClaimSource != nil:
return newClaimInjector(name, value.ClaimSource)
default:
return nil, fmt.Errorf("header %q value has multiple entries: only one entry per value is allowed", name)
}
}
type injectorFunc struct {
injectFunc func(http.Header, *sessionsapi.SessionState)
}
func (i *injectorFunc) inject(header http.Header, session *sessionsapi.SessionState) {
i.injectFunc(header, session)
}
func newInjectorFunc(injectFunc func(header http.Header, session *sessionsapi.SessionState)) valueInjector {
return &injectorFunc{injectFunc: injectFunc}
}
func newSecretInjector(name string, source *options.SecretSource) (valueInjector, error) {
value, err := util.GetSecretValue(source)
if err != nil {
return nil, fmt.Errorf("error getting secret value: %v", err)
}
return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) {
header.Add(name, string(value))
}), nil
}
func newClaimInjector(name string, source *options.ClaimSource) (valueInjector, error) {
switch {
case source.BasicAuthPassword != nil:
password, err := util.GetSecretValue(source.BasicAuthPassword)
if err != nil {
return nil, fmt.Errorf("error loading basicAuthPassword: %v", err)
}
return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) {
claimValues := session.GetClaim(source.Claim)
for _, claim := range claimValues {
if claim == "" {
continue
}
auth := claim + ":" + string(password)
header.Add(name, "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
}
}), nil
case source.Prefix != "":
return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) {
claimValues := session.GetClaim(source.Claim)
for _, claim := range claimValues {
if claim == "" {
continue
}
header.Add(name, source.Prefix+claim)
}
}), nil
default:
return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) {
claimValues := session.GetClaim(source.Claim)
for _, claim := range claimValues {
if claim == "" {
continue
}
header.Add(name, claim)
}
}), nil
}
}