diff --git a/CHANGELOG.md b/CHANGELOG.md index c3037b42..d61f800f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ ## Changes since v6.1.1 +- [#705](https://github.com/oauth2-proxy/oauth2-proxy/pull/705) Add generic Header injectors for upstream request and response headers (@JoelSpeed) - [#753](https://github.com/oauth2-proxy/oauth2-proxy/pull/753) Pass resource parameter in login url (@codablock) - [#789](https://github.com/oauth2-proxy/oauth2-proxy/pull/789) Add `--skip-auth-route` configuration option for `METHOD=pathRegex` based allowlists (@NickMeves) - [#575](https://github.com/oauth2-proxy/oauth2-proxy/pull/575) Stop accepting legacy SHA1 signed cookies (@NickMeves) diff --git a/pkg/apis/options/common.go b/pkg/apis/options/common.go new file mode 100644 index 00000000..60d352a5 --- /dev/null +++ b/pkg/apis/options/common.go @@ -0,0 +1,14 @@ +package options + +// SecretSource references an individual secret value. +// Only one source within the struct should be defined at any time. +type SecretSource struct { + // Value expects a base64 encoded string value. + Value []byte + + // FromEnv expects the name of an environment variable. + FromEnv string + + // FromFile expects a path to a file containing the secret value. + FromFile string +} diff --git a/pkg/apis/options/header.go b/pkg/apis/options/header.go new file mode 100644 index 00000000..0b2e1b69 --- /dev/null +++ b/pkg/apis/options/header.go @@ -0,0 +1,44 @@ +package options + +// Header represents an individual header that will be added to a request or +// response header. +type Header struct { + // Name is the header name to be used for this set of values. + // Names should be unique within a list of Headers. + Name string `json:"name"` + + // PreserveRequestValue determines whether any values for this header + // should be preserved for the request to the upstream server. + // This option only takes effet on injected request headers. + // Defaults to false (headers that match this header will be stripped). + PreserveRequestValue bool `json:"preserveRequestValue"` + + // Values contains the desired values for this header + Values []HeaderValue `json:"values"` +} + +// HeaderValue represents a single header value and the sources that can +// make up the header value +type HeaderValue struct { + // Allow users to load the value from a secret source + *SecretSource + + // Allow users to load the value from a session claim + *ClaimSource +} + +// ClaimSource allows loading a header value from a claim within the session +type ClaimSource struct { + // Claim is the name of the claim in the session that the value should be + // loaded from. + Claim string `json:"claim,omitempty"` + + // Prefix is an optional prefix that will be prepended to the value of the + // claim if it is non-empty. + Prefix string `json:"prefix,omitempty"` + + // BasicAuthPassword converts this claim into a basic auth header. + // Note the value of claim will become the basic auth username and the + // basicAuthPassword will be used as the password value. + BasicAuthPassword *SecretSource +} diff --git a/pkg/apis/options/util/util.go b/pkg/apis/options/util/util.go new file mode 100644 index 00000000..918da13a --- /dev/null +++ b/pkg/apis/options/util/util.go @@ -0,0 +1,26 @@ +package util + +import ( + "encoding/base64" + "errors" + "io/ioutil" + "os" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" +) + +// GetSecretValue returns the value of the Secret from its source +func GetSecretValue(source *options.SecretSource) ([]byte, error) { + switch { + case len(source.Value) > 0 && source.FromEnv == "" && source.FromFile == "": + value := make([]byte, base64.StdEncoding.DecodedLen(len(source.Value))) + decoded, err := base64.StdEncoding.Decode(value, source.Value) + return value[:decoded], err + case len(source.Value) == 0 && source.FromEnv != "" && source.FromFile == "": + return []byte(os.Getenv(source.FromEnv)), nil + case len(source.Value) == 0 && source.FromEnv == "" && source.FromFile != "": + return ioutil.ReadFile(source.FromFile) + default: + return nil, errors.New("secret source is invalid: exactly one entry required, specify either value, fromEnv or fromFile") + } +} diff --git a/pkg/apis/options/util/util_suite_test.go b/pkg/apis/options/util/util_suite_test.go new file mode 100644 index 00000000..75f53dbb --- /dev/null +++ b/pkg/apis/options/util/util_suite_test.go @@ -0,0 +1,16 @@ +package util + +import ( + "testing" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestUtilSuite(t *testing.T) { + logger.SetOutput(GinkgoWriter) + + RegisterFailHandler(Fail) + RunSpecs(t, "Options Util Suite") +} diff --git a/pkg/apis/options/util/util_test.go b/pkg/apis/options/util/util_test.go new file mode 100644 index 00000000..5ca76a04 --- /dev/null +++ b/pkg/apis/options/util/util_test.go @@ -0,0 +1,88 @@ +package util + +import ( + "encoding/base64" + "io/ioutil" + "os" + "path" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("GetSecretValue", func() { + var fileDir string + const secretEnvKey = "SECRET_ENV_KEY" + const secretEnvValue = "secret-env-value" + var secretFileValue = []byte("secret-file-value") + + BeforeEach(func() { + os.Setenv(secretEnvKey, secretEnvValue) + + var err error + fileDir, err = ioutil.TempDir("", "oauth2-proxy-util-get-secret-value") + Expect(err).ToNot(HaveOccurred()) + Expect(ioutil.WriteFile(path.Join(fileDir, "secret-file"), secretFileValue, 0600)).To(Succeed()) + }) + + AfterEach(func() { + os.Unsetenv(secretEnvKey) + os.RemoveAll(fileDir) + }) + + It("returns the correct value from base64", func() { + originalValue := []byte("secret-value-1") + b64Value := base64.StdEncoding.EncodeToString((originalValue)) + + // Once encoded, the originalValue could have a decoded length longer than + // its actual length, ensure we trim this. + // This assertion ensures we are testing the triming + Expect(len(originalValue)).To(BeNumerically("<", base64.StdEncoding.DecodedLen(len(b64Value)))) + + value, err := GetSecretValue(&options.SecretSource{ + Value: []byte(b64Value), + }) + Expect(err).ToNot(HaveOccurred()) + Expect(value).To(Equal(originalValue)) + }) + + It("returns the correct value from the environment", func() { + value, err := GetSecretValue(&options.SecretSource{ + FromEnv: secretEnvKey, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(value).To(BeEquivalentTo(secretEnvValue)) + }) + + It("returns the correct value from a file", func() { + value, err := GetSecretValue(&options.SecretSource{ + FromFile: path.Join(fileDir, "secret-file"), + }) + Expect(err).ToNot(HaveOccurred()) + Expect(value).To(Equal(secretFileValue)) + }) + + It("when the file does not exist", func() { + value, err := GetSecretValue(&options.SecretSource{ + FromFile: path.Join(fileDir, "not-exist"), + }) + Expect(err).To(HaveOccurred()) + Expect(value).To(BeEmpty()) + }) + + It("with no source set", func() { + value, err := GetSecretValue(&options.SecretSource{}) + Expect(err).To(MatchError("secret source is invalid: exactly one entry required, specify either value, fromEnv or fromFile")) + Expect(value).To(BeEmpty()) + }) + + It("with multiple sources set", func() { + value, err := GetSecretValue(&options.SecretSource{ + FromEnv: secretEnvKey, + FromFile: path.Join(fileDir, "secret-file"), + }) + Expect(err).To(MatchError("secret source is invalid: exactly one entry required, specify either value, fromEnv or fromFile")) + Expect(value).To(BeEmpty()) + }) +}) diff --git a/pkg/apis/sessions/session_state.go b/pkg/apis/sessions/session_state.go index c3db8994..03bc747a 100644 --- a/pkg/apis/sessions/session_state.go +++ b/pkg/apis/sessions/session_state.go @@ -69,6 +69,36 @@ func (s *SessionState) String() string { return o + "}" } +func (s *SessionState) GetClaim(claim string) []string { + if s == nil { + return []string{} + } + switch claim { + case "access_token": + return []string{s.AccessToken} + case "id_token": + return []string{s.IDToken} + case "created_at": + return []string{s.CreatedAt.String()} + case "expires_on": + return []string{s.ExpiresOn.String()} + case "refresh_token": + return []string{s.RefreshToken} + case "email": + return []string{s.Email} + case "user": + return []string{s.User} + case "groups": + groups := make([]string, len(s.Groups)) + copy(groups, s.Groups) + return groups + case "preferred_username": + return []string{s.PreferredUsername} + default: + return []string{} + } +} + // EncodeSessionState returns an encrypted, lz4 compressed, MessagePack encoded session func (s *SessionState) EncodeSessionState(c encryption.Cipher, compress bool) ([]byte, error) { packed, err := msgpack.Marshal(s) diff --git a/pkg/header/header_suite_test.go b/pkg/header/header_suite_test.go new file mode 100644 index 00000000..3d05cd02 --- /dev/null +++ b/pkg/header/header_suite_test.go @@ -0,0 +1,37 @@ +package header + +import ( + "io/ioutil" + "os" + "path" + "testing" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var ( + filesDir string +) + +func TestHeaderSuite(t *testing.T) { + logger.SetOutput(GinkgoWriter) + + RegisterFailHandler(Fail) + RunSpecs(t, "Header") +} + +var _ = BeforeSuite(func() { + os.Setenv("SECRET_ENV", "super-secret-env") + + dir, err := ioutil.TempDir("", "oauth2-proxy-header-suite") + Expect(err).ToNot(HaveOccurred()) + Expect(ioutil.WriteFile(path.Join(dir, "secret-file"), []byte("super-secret-file"), 0644)).To(Succeed()) + filesDir = dir +}) + +var _ = AfterSuite(func() { + os.Unsetenv("SECRET_ENV") + Expect(os.RemoveAll(filesDir)).To(Succeed()) +}) diff --git a/pkg/header/injector.go b/pkg/header/injector.go new file mode 100644 index 00000000..9c6e2fcd --- /dev/null +++ b/pkg/header/injector.go @@ -0,0 +1,118 @@ +package header + +import ( + "encoding/base64" + "fmt" + "net/http" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options/util" + sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" +) + +type Injector interface { + Inject(http.Header, *sessionsapi.SessionState) +} + +type injector struct { + valueInjectors []valueInjector +} + +func (i injector) Inject(header http.Header, session *sessionsapi.SessionState) { + for _, injector := range i.valueInjectors { + injector.inject(header, session) + } +} + +func NewInjector(headers []options.Header) (Injector, error) { + injectors := []valueInjector{} + for _, header := range headers { + for _, value := range header.Values { + injector, err := newValueinjector(header.Name, value) + if err != nil { + return nil, fmt.Errorf("error building injector for header %q: %v", header.Name, err) + } + injectors = append(injectors, injector) + } + } + + return &injector{valueInjectors: injectors}, nil +} + +type valueInjector interface { + inject(http.Header, *sessionsapi.SessionState) +} + +func newValueinjector(name string, value options.HeaderValue) (valueInjector, error) { + switch { + case value.SecretSource != nil && value.ClaimSource == nil: + return newSecretInjector(name, value.SecretSource) + case value.SecretSource == nil && value.ClaimSource != nil: + return newClaimInjector(name, value.ClaimSource) + default: + return nil, fmt.Errorf("header %q value has multiple entries: only one entry per value is allowed", name) + } +} + +type injectorFunc struct { + injectFunc func(http.Header, *sessionsapi.SessionState) +} + +func (i *injectorFunc) inject(header http.Header, session *sessionsapi.SessionState) { + i.injectFunc(header, session) +} + +func newInjectorFunc(injectFunc func(header http.Header, session *sessionsapi.SessionState)) valueInjector { + return &injectorFunc{injectFunc: injectFunc} +} + +func newSecretInjector(name string, source *options.SecretSource) (valueInjector, error) { + value, err := util.GetSecretValue(source) + if err != nil { + return nil, fmt.Errorf("error getting secret value: %v", err) + } + + return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) { + header.Add(name, string(value)) + }), nil +} + +func newClaimInjector(name string, source *options.ClaimSource) (valueInjector, error) { + switch { + case source.BasicAuthPassword != nil: + password, err := util.GetSecretValue(source.BasicAuthPassword) + if err != nil { + return nil, fmt.Errorf("error loading basicAuthPassword: %v", err) + } + return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) { + claimValues := session.GetClaim(source.Claim) + for _, claim := range claimValues { + if claim == "" { + continue + } + auth := claim + ":" + string(password) + header.Add(name, "Basic "+base64.StdEncoding.EncodeToString([]byte(auth))) + } + }), nil + case source.Prefix != "": + return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) { + claimValues := session.GetClaim(source.Claim) + for _, claim := range claimValues { + if claim == "" { + continue + } + header.Add(name, source.Prefix+claim) + } + }), nil + default: + return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) { + claimValues := session.GetClaim(source.Claim) + for _, claim := range claimValues { + if claim == "" { + continue + } + header.Add(name, claim) + } + }), nil + } +} diff --git a/pkg/header/injector_test.go b/pkg/header/injector_test.go new file mode 100644 index 00000000..af034fd9 --- /dev/null +++ b/pkg/header/injector_test.go @@ -0,0 +1,417 @@ +package header + +import ( + "encoding/base64" + "errors" + "net/http" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" + sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("Injector Suite", func() { + Context("NewInjector", func() { + type newInjectorTableInput struct { + headers []options.Header + initialHeaders http.Header + session *sessionsapi.SessionState + expectedHeaders http.Header + expectedErr error + } + + DescribeTable("creates an injector", + func(in newInjectorTableInput) { + injector, err := NewInjector(in.headers) + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr)) + Expect(injector).To(BeNil()) + return + } + + Expect(err).ToNot(HaveOccurred()) + Expect(injector).ToNot(BeNil()) + + headers := in.initialHeaders.Clone() + injector.Inject(headers, in.session) + Expect(headers).To(Equal(in.expectedHeaders)) + }, + Entry("with no configured headers", newInjectorTableInput{ + headers: []options.Header{}, + initialHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{}, + expectedHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + expectedErr: nil, + }), + Entry("with a static valued header from base64", newInjectorTableInput{ + headers: []options.Header{ + { + Name: "Secret", + Values: []options.HeaderValue{ + { + SecretSource: &options.SecretSource{ + Value: []byte(base64.StdEncoding.EncodeToString([]byte("super-secret"))), + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{}, + expectedHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + "Secret": []string{"super-secret"}, + }, + expectedErr: nil, + }), + Entry("with a static valued header from env", newInjectorTableInput{ + headers: []options.Header{ + { + Name: "Secret", + Values: []options.HeaderValue{ + { + SecretSource: &options.SecretSource{ + FromEnv: "SECRET_ENV", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{}, + expectedHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + "Secret": []string{"super-secret-env"}, + }, + expectedErr: nil, + }), + Entry("with a claim valued header", newInjectorTableInput{ + headers: []options.Header{ + { + Name: "Claim", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{ + IDToken: "IDToken-1234", + }, + expectedHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + "Claim": []string{"IDToken-1234"}, + }, + expectedErr: nil, + }), + Entry("with a claim valued header and a nil session", newInjectorTableInput{ + headers: []options.Header{ + { + Name: "Claim", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + session: nil, + expectedHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + expectedErr: nil, + }), + Entry("with a prefixed claim valued header", newInjectorTableInput{ + headers: []options.Header{ + { + Name: "Claim", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + Prefix: "Bearer ", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{ + IDToken: "IDToken-1234", + }, + expectedHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + "Claim": []string{"Bearer IDToken-1234"}, + }, + expectedErr: nil, + }), + Entry("with a prefixed claim valued header missing the claim", newInjectorTableInput{ + headers: []options.Header{ + { + Name: "Claim", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "idToken", + Prefix: "Bearer ", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{}, + expectedHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + expectedErr: nil, + }), + Entry("with a basicAuthPassword and claim valued header", newInjectorTableInput{ + headers: []options.Header{ + { + Name: "X-Auth-Request-Authorization", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "user", + BasicAuthPassword: &options.SecretSource{ + Value: []byte(base64.StdEncoding.EncodeToString([]byte("basic-password"))), + }, + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{ + User: "user-123", + }, + expectedHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + "X-Auth-Request-Authorization": []string{"Basic " + base64.StdEncoding.EncodeToString([]byte("user-123:basic-password"))}, + }, + expectedErr: nil, + }), + Entry("with a basicAuthPassword and claim valued header missing the claim", newInjectorTableInput{ + headers: []options.Header{ + { + Name: "X-Auth-Request-Authorization", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "user", + BasicAuthPassword: &options.SecretSource{ + Value: []byte(base64.StdEncoding.EncodeToString([]byte("basic-password"))), + }, + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{}, + expectedHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + expectedErr: nil, + }), + Entry("with a header that already exists", newInjectorTableInput{ + headers: []options.Header{ + { + Name: "X-Auth-Request-User", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "user", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "X-Auth-Request-User": []string{"user"}, + }, + session: &sessionsapi.SessionState{ + User: "user-123", + }, + expectedHeaders: http.Header{ + "X-Auth-Request-User": []string{"user", "user-123"}, + }, + expectedErr: nil, + }), + Entry("with a claim and secret valued header value", newInjectorTableInput{ + headers: []options.Header{ + { + Name: "Claim", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + SecretSource: &options.SecretSource{ + FromEnv: "SECRET_ENV", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{ + IDToken: "IDToken-1234", + }, + expectedHeaders: nil, + expectedErr: errors.New("error building injector for header \"Claim\": header \"Claim\" value has multiple entries: only one entry per value is allowed"), + }), + Entry("with an invalid static valued header", newInjectorTableInput{ + headers: []options.Header{ + { + Name: "Secret", + Values: []options.HeaderValue{ + { + SecretSource: &options.SecretSource{ + FromEnv: "SECRET_ENV", + FromFile: "secret-file", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{}, + expectedHeaders: nil, + expectedErr: errors.New("error building injector for header \"Secret\": error getting secret value: secret source is invalid: exactly one entry required, specify either value, fromEnv or fromFile"), + }), + Entry("with an invalid basicAuthPassword claim valued header", newInjectorTableInput{ + headers: []options.Header{ + { + Name: "X-Auth-Request-Authorization", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "user", + BasicAuthPassword: &options.SecretSource{ + Value: []byte(base64.StdEncoding.EncodeToString([]byte("basic-password"))), + FromEnv: "SECRET_ENV", + }, + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{ + User: "user-123", + }, + expectedHeaders: nil, + expectedErr: errors.New("error building injector for header \"X-Auth-Request-Authorization\": error loading basicAuthPassword: secret source is invalid: exactly one entry required, specify either value, fromEnv or fromFile"), + }), + Entry("with a mix of configured headers", newInjectorTableInput{ + headers: []options.Header{ + { + Name: "X-Auth-Request-Authorization", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "user", + BasicAuthPassword: &options.SecretSource{ + Value: []byte(base64.StdEncoding.EncodeToString([]byte("basic-password"))), + }, + }, + }, + }, + }, + { + Name: "X-Auth-Request-User", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "user", + }, + }, + }, + }, + { + Name: "X-Auth-Request-Email", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "email", + }, + }, + }, + }, + { + Name: "X-Auth-Request-Version-Info", + Values: []options.HeaderValue{ + { + SecretSource: &options.SecretSource{ + Value: []byte(base64.StdEncoding.EncodeToString([]byte("major=1"))), + }, + }, + { + SecretSource: &options.SecretSource{ + Value: []byte(base64.StdEncoding.EncodeToString([]byte("minor=2"))), + }, + }, + { + SecretSource: &options.SecretSource{ + Value: []byte(base64.StdEncoding.EncodeToString([]byte("patch=3"))), + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{ + User: "user-123", + Email: "user@example.com", + }, + expectedHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + "X-Auth-Request-Authorization": []string{"Basic " + base64.StdEncoding.EncodeToString([]byte("user-123:basic-password"))}, + "X-Auth-Request-User": []string{"user-123"}, + "X-Auth-Request-Email": []string{"user@example.com"}, + "X-Auth-Request-Version-Info": []string{"major=1", "minor=2", "patch=3"}, + }, + expectedErr: nil, + }), + ) + }) +}) diff --git a/pkg/middleware/headers.go b/pkg/middleware/headers.go new file mode 100644 index 00000000..6786c2eb --- /dev/null +++ b/pkg/middleware/headers.go @@ -0,0 +1,102 @@ +package middleware + +import ( + "fmt" + "net/http" + + "github.com/justinas/alice" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/header" +) + +func NewRequestHeaderInjector(headers []options.Header) (alice.Constructor, error) { + headerInjector, err := newRequestHeaderInjector(headers) + if err != nil { + return nil, fmt.Errorf("error building request header injector: %v", err) + } + + strip := newStripHeaders(headers) + if strip != nil { + return alice.New(strip, headerInjector).Then, nil + } + return headerInjector, nil +} + +func newStripHeaders(headers []options.Header) alice.Constructor { + headersToStrip := []string{} + for _, header := range headers { + if !header.PreserveRequestValue { + headersToStrip = append(headersToStrip, header.Name) + } + } + + if len(headersToStrip) == 0 { + return nil + } + + return func(next http.Handler) http.Handler { + return stripHeaders(headersToStrip, next) + } +} + +func stripHeaders(headers []string, next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + for _, header := range headers { + req.Header.Del(header) + } + next.ServeHTTP(rw, req) + }) +} + +func newRequestHeaderInjector(headers []options.Header) (alice.Constructor, error) { + injector, err := header.NewInjector(headers) + if err != nil { + return nil, fmt.Errorf("error building request injector: %v", err) + } + + return func(next http.Handler) http.Handler { + return injectRequestHeaders(injector, next) + }, nil +} + +func injectRequestHeaders(injector header.Injector, next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + scope := GetRequestScope(req) + + // If scope is nil, this will panic. + // A scope should always be injected before this handler is called. + injector.Inject(req.Header, scope.Session) + next.ServeHTTP(rw, req) + }) +} + +func NewResponseHeaderInjector(headers []options.Header) (alice.Constructor, error) { + headerInjector, err := newResponseHeaderInjector(headers) + if err != nil { + return nil, fmt.Errorf("error building response header injector: %v", err) + } + + return headerInjector, nil +} + +func newResponseHeaderInjector(headers []options.Header) (alice.Constructor, error) { + injector, err := header.NewInjector(headers) + if err != nil { + return nil, fmt.Errorf("error building response injector: %v", err) + } + + return func(next http.Handler) http.Handler { + return injectResponseHeaders(injector, next) + }, nil +} + +func injectResponseHeaders(injector header.Injector, next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + scope := GetRequestScope(req) + + // If scope is nil, this will panic. + // A scope should always be injected before this handler is called. + injector.Inject(rw.Header(), scope.Session) + next.ServeHTTP(rw, req) + }) +} diff --git a/pkg/middleware/headers_test.go b/pkg/middleware/headers_test.go new file mode 100644 index 00000000..15006b1d --- /dev/null +++ b/pkg/middleware/headers_test.go @@ -0,0 +1,405 @@ +package middleware + +import ( + "context" + "encoding/base64" + "net/http" + "net/http/httptest" + + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" + sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("Headers Suite", func() { + type headersTableInput struct { + headers []options.Header + initialHeaders http.Header + session *sessionsapi.SessionState + expectedHeaders http.Header + expectedErr string + } + + DescribeTable("the request header injector", + func(in headersTableInput) { + scope := &middlewareapi.RequestScope{ + Session: in.session, + } + + // Set up the request with a request scope + req := httptest.NewRequest("", "/", nil) + contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) + req = req.WithContext(contextWithScope) + req.Header = in.initialHeaders.Clone() + + rw := httptest.NewRecorder() + + // Create the handler with a next handler that will capture the headers + // from the request + var gotHeaders http.Header + injector, err := NewRequestHeaderInjector(in.headers) + if in.expectedErr != "" { + Expect(err).To(MatchError(in.expectedErr)) + return + } + Expect(err).ToNot(HaveOccurred()) + + handler := injector(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotHeaders = r.Header.Clone() + })) + handler.ServeHTTP(rw, req) + + Expect(gotHeaders).To(Equal(in.expectedHeaders)) + }, + Entry("with no configured headers", headersTableInput{ + headers: []options.Header{}, + initialHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{}, + expectedHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + expectedErr: "", + }), + Entry("with a claim valued header", headersTableInput{ + headers: []options.Header{ + { + Name: "Claim", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{ + IDToken: "IDToken-1234", + }, + expectedHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + "Claim": []string{"IDToken-1234"}, + }, + expectedErr: "", + }), + Entry("with a claim valued header (without preservation)", headersTableInput{ + headers: []options.Header{ + { + Name: "Claim", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "Claim": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{ + IDToken: "IDToken-1234", + }, + expectedHeaders: http.Header{ + "Claim": []string{"IDToken-1234"}, + }, + expectedErr: "", + }), + Entry("with a claim valued header (with preservation)", headersTableInput{ + headers: []options.Header{ + { + Name: "Claim", + PreserveRequestValue: true, + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "Claim": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{ + IDToken: "IDToken-1234", + }, + expectedHeaders: http.Header{ + "Claim": []string{"bar", "baz", "IDToken-1234"}, + }, + expectedErr: "", + }), + Entry("with a claim valued header that's not present (without preservation)", headersTableInput{ + headers: []options.Header{ + { + Name: "Claim", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "Claim": []string{"bar", "baz"}, + }, + session: nil, + expectedHeaders: http.Header{}, + expectedErr: "", + }), + Entry("with a claim valued header that's not present (with preservation)", headersTableInput{ + headers: []options.Header{ + { + Name: "Claim", + PreserveRequestValue: true, + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "Claim": []string{"bar", "baz"}, + }, + session: nil, + expectedHeaders: http.Header{ + "Claim": []string{"bar", "baz"}, + }, + expectedErr: "", + }), + Entry("with an invalid basicAuthPassword claim valued header", headersTableInput{ + headers: []options.Header{ + { + Name: "X-Auth-Request-Authorization", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "user", + BasicAuthPassword: &options.SecretSource{ + Value: []byte(base64.StdEncoding.EncodeToString([]byte("basic-password"))), + FromEnv: "SECRET_ENV", + }, + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{ + User: "user-123", + }, + expectedHeaders: nil, + expectedErr: "error building request header injector: error building request injector: error building injector for header \"X-Auth-Request-Authorization\": error loading basicAuthPassword: secret source is invalid: exactly one entry required, specify either value, fromEnv or fromFile", + }), + ) + + DescribeTable("the response header injector", + func(in headersTableInput) { + scope := &middlewareapi.RequestScope{ + Session: in.session, + } + + // Set up the request with a request scope + req := httptest.NewRequest("", "/", nil) + contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) + req = req.WithContext(contextWithScope) + + rw := httptest.NewRecorder() + for key, values := range in.initialHeaders { + for _, value := range values { + rw.Header().Add(key, value) + } + } + + // Create the handler with a next handler that will capture the headers + // from the request + var gotHeaders http.Header + injector, err := NewResponseHeaderInjector(in.headers) + if in.expectedErr != "" { + Expect(err).To(MatchError(in.expectedErr)) + return + } + Expect(err).ToNot(HaveOccurred()) + + handler := injector(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotHeaders = w.Header().Clone() + })) + handler.ServeHTTP(rw, req) + + Expect(gotHeaders).To(Equal(in.expectedHeaders)) + }, + Entry("with no configured headers", headersTableInput{ + headers: []options.Header{}, + initialHeaders: http.Header{ + "Foo": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{}, + expectedHeaders: http.Header{ + "Foo": []string{"bar", "baz"}, + }, + expectedErr: "", + }), + Entry("with a claim valued header", headersTableInput{ + headers: []options.Header{ + { + Name: "Claim", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "Foo": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{ + IDToken: "IDToken-1234", + }, + expectedHeaders: http.Header{ + "Foo": []string{"bar", "baz"}, + "Claim": []string{"IDToken-1234"}, + }, + expectedErr: "", + }), + Entry("with a claim valued header (without preservation)", headersTableInput{ + headers: []options.Header{ + { + Name: "Claim", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "Claim": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{ + IDToken: "IDToken-1234", + }, + expectedHeaders: http.Header{ + "Claim": []string{"bar", "baz", "IDToken-1234"}, + }, + expectedErr: "", + }), + Entry("with a claim valued header (with preservation)", headersTableInput{ + headers: []options.Header{ + { + Name: "Claim", + PreserveRequestValue: true, + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "Claim": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{ + IDToken: "IDToken-1234", + }, + expectedHeaders: http.Header{ + "Claim": []string{"bar", "baz", "IDToken-1234"}, + }, + expectedErr: "", + }), + Entry("with a claim valued header that's not present (without preservation)", headersTableInput{ + headers: []options.Header{ + { + Name: "Claim", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "Claim": []string{"bar", "baz"}, + }, + session: nil, + expectedHeaders: http.Header{ + "Claim": []string{"bar", "baz"}, + }, + expectedErr: "", + }), + Entry("with a claim valued header that's not present (with preservation)", headersTableInput{ + headers: []options.Header{ + { + Name: "Claim", + PreserveRequestValue: true, + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "Claim": []string{"bar", "baz"}, + }, + session: nil, + expectedHeaders: http.Header{ + "Claim": []string{"bar", "baz"}, + }, + expectedErr: "", + }), + Entry("with an invalid basicAuthPassword claim valued header", headersTableInput{ + headers: []options.Header{ + { + Name: "X-Auth-Request-Authorization", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "user", + BasicAuthPassword: &options.SecretSource{ + Value: []byte(base64.StdEncoding.EncodeToString([]byte("basic-password"))), + FromEnv: "SECRET_ENV", + }, + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{ + User: "user-123", + }, + expectedHeaders: nil, + expectedErr: "error building response header injector: error building response injector: error building injector for header \"X-Auth-Request-Authorization\": error loading basicAuthPassword: secret source is invalid: exactly one entry required, specify either value, fromEnv or fromFile", + }), + ) +})