From 6743e3991d4a0da3b40ad124877fabfa3234b7a5 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Sun, 26 Jul 2020 04:50:39 +0100 Subject: [PATCH] Add header injector middlewares --- pkg/middleware/headers.go | 102 +++++++++ pkg/middleware/headers_test.go | 405 +++++++++++++++++++++++++++++++++ 2 files changed, 507 insertions(+) create mode 100644 pkg/middleware/headers.go create mode 100644 pkg/middleware/headers_test.go diff --git a/pkg/middleware/headers.go b/pkg/middleware/headers.go new file mode 100644 index 00000000..6786c2eb --- /dev/null +++ b/pkg/middleware/headers.go @@ -0,0 +1,102 @@ +package middleware + +import ( + "fmt" + "net/http" + + "github.com/justinas/alice" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/header" +) + +func NewRequestHeaderInjector(headers []options.Header) (alice.Constructor, error) { + headerInjector, err := newRequestHeaderInjector(headers) + if err != nil { + return nil, fmt.Errorf("error building request header injector: %v", err) + } + + strip := newStripHeaders(headers) + if strip != nil { + return alice.New(strip, headerInjector).Then, nil + } + return headerInjector, nil +} + +func newStripHeaders(headers []options.Header) alice.Constructor { + headersToStrip := []string{} + for _, header := range headers { + if !header.PreserveRequestValue { + headersToStrip = append(headersToStrip, header.Name) + } + } + + if len(headersToStrip) == 0 { + return nil + } + + return func(next http.Handler) http.Handler { + return stripHeaders(headersToStrip, next) + } +} + +func stripHeaders(headers []string, next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + for _, header := range headers { + req.Header.Del(header) + } + next.ServeHTTP(rw, req) + }) +} + +func newRequestHeaderInjector(headers []options.Header) (alice.Constructor, error) { + injector, err := header.NewInjector(headers) + if err != nil { + return nil, fmt.Errorf("error building request injector: %v", err) + } + + return func(next http.Handler) http.Handler { + return injectRequestHeaders(injector, next) + }, nil +} + +func injectRequestHeaders(injector header.Injector, next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + scope := GetRequestScope(req) + + // If scope is nil, this will panic. + // A scope should always be injected before this handler is called. + injector.Inject(req.Header, scope.Session) + next.ServeHTTP(rw, req) + }) +} + +func NewResponseHeaderInjector(headers []options.Header) (alice.Constructor, error) { + headerInjector, err := newResponseHeaderInjector(headers) + if err != nil { + return nil, fmt.Errorf("error building response header injector: %v", err) + } + + return headerInjector, nil +} + +func newResponseHeaderInjector(headers []options.Header) (alice.Constructor, error) { + injector, err := header.NewInjector(headers) + if err != nil { + return nil, fmt.Errorf("error building response injector: %v", err) + } + + return func(next http.Handler) http.Handler { + return injectResponseHeaders(injector, next) + }, nil +} + +func injectResponseHeaders(injector header.Injector, next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + scope := GetRequestScope(req) + + // If scope is nil, this will panic. + // A scope should always be injected before this handler is called. + injector.Inject(rw.Header(), scope.Session) + next.ServeHTTP(rw, req) + }) +} diff --git a/pkg/middleware/headers_test.go b/pkg/middleware/headers_test.go new file mode 100644 index 00000000..15006b1d --- /dev/null +++ b/pkg/middleware/headers_test.go @@ -0,0 +1,405 @@ +package middleware + +import ( + "context" + "encoding/base64" + "net/http" + "net/http/httptest" + + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" + sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("Headers Suite", func() { + type headersTableInput struct { + headers []options.Header + initialHeaders http.Header + session *sessionsapi.SessionState + expectedHeaders http.Header + expectedErr string + } + + DescribeTable("the request header injector", + func(in headersTableInput) { + scope := &middlewareapi.RequestScope{ + Session: in.session, + } + + // Set up the request with a request scope + req := httptest.NewRequest("", "/", nil) + contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) + req = req.WithContext(contextWithScope) + req.Header = in.initialHeaders.Clone() + + rw := httptest.NewRecorder() + + // Create the handler with a next handler that will capture the headers + // from the request + var gotHeaders http.Header + injector, err := NewRequestHeaderInjector(in.headers) + if in.expectedErr != "" { + Expect(err).To(MatchError(in.expectedErr)) + return + } + Expect(err).ToNot(HaveOccurred()) + + handler := injector(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotHeaders = r.Header.Clone() + })) + handler.ServeHTTP(rw, req) + + Expect(gotHeaders).To(Equal(in.expectedHeaders)) + }, + Entry("with no configured headers", headersTableInput{ + headers: []options.Header{}, + initialHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{}, + expectedHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + expectedErr: "", + }), + Entry("with a claim valued header", headersTableInput{ + headers: []options.Header{ + { + Name: "Claim", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{ + IDToken: "IDToken-1234", + }, + expectedHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + "Claim": []string{"IDToken-1234"}, + }, + expectedErr: "", + }), + Entry("with a claim valued header (without preservation)", headersTableInput{ + headers: []options.Header{ + { + Name: "Claim", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "Claim": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{ + IDToken: "IDToken-1234", + }, + expectedHeaders: http.Header{ + "Claim": []string{"IDToken-1234"}, + }, + expectedErr: "", + }), + Entry("with a claim valued header (with preservation)", headersTableInput{ + headers: []options.Header{ + { + Name: "Claim", + PreserveRequestValue: true, + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "Claim": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{ + IDToken: "IDToken-1234", + }, + expectedHeaders: http.Header{ + "Claim": []string{"bar", "baz", "IDToken-1234"}, + }, + expectedErr: "", + }), + Entry("with a claim valued header that's not present (without preservation)", headersTableInput{ + headers: []options.Header{ + { + Name: "Claim", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "Claim": []string{"bar", "baz"}, + }, + session: nil, + expectedHeaders: http.Header{}, + expectedErr: "", + }), + Entry("with a claim valued header that's not present (with preservation)", headersTableInput{ + headers: []options.Header{ + { + Name: "Claim", + PreserveRequestValue: true, + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "Claim": []string{"bar", "baz"}, + }, + session: nil, + expectedHeaders: http.Header{ + "Claim": []string{"bar", "baz"}, + }, + expectedErr: "", + }), + Entry("with an invalid basicAuthPassword claim valued header", headersTableInput{ + headers: []options.Header{ + { + Name: "X-Auth-Request-Authorization", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "user", + BasicAuthPassword: &options.SecretSource{ + Value: []byte(base64.StdEncoding.EncodeToString([]byte("basic-password"))), + FromEnv: "SECRET_ENV", + }, + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{ + User: "user-123", + }, + expectedHeaders: nil, + expectedErr: "error building request header injector: error building request injector: error building injector for header \"X-Auth-Request-Authorization\": error loading basicAuthPassword: secret source is invalid: exactly one entry required, specify either value, fromEnv or fromFile", + }), + ) + + DescribeTable("the response header injector", + func(in headersTableInput) { + scope := &middlewareapi.RequestScope{ + Session: in.session, + } + + // Set up the request with a request scope + req := httptest.NewRequest("", "/", nil) + contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) + req = req.WithContext(contextWithScope) + + rw := httptest.NewRecorder() + for key, values := range in.initialHeaders { + for _, value := range values { + rw.Header().Add(key, value) + } + } + + // Create the handler with a next handler that will capture the headers + // from the request + var gotHeaders http.Header + injector, err := NewResponseHeaderInjector(in.headers) + if in.expectedErr != "" { + Expect(err).To(MatchError(in.expectedErr)) + return + } + Expect(err).ToNot(HaveOccurred()) + + handler := injector(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotHeaders = w.Header().Clone() + })) + handler.ServeHTTP(rw, req) + + Expect(gotHeaders).To(Equal(in.expectedHeaders)) + }, + Entry("with no configured headers", headersTableInput{ + headers: []options.Header{}, + initialHeaders: http.Header{ + "Foo": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{}, + expectedHeaders: http.Header{ + "Foo": []string{"bar", "baz"}, + }, + expectedErr: "", + }), + Entry("with a claim valued header", headersTableInput{ + headers: []options.Header{ + { + Name: "Claim", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "Foo": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{ + IDToken: "IDToken-1234", + }, + expectedHeaders: http.Header{ + "Foo": []string{"bar", "baz"}, + "Claim": []string{"IDToken-1234"}, + }, + expectedErr: "", + }), + Entry("with a claim valued header (without preservation)", headersTableInput{ + headers: []options.Header{ + { + Name: "Claim", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "Claim": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{ + IDToken: "IDToken-1234", + }, + expectedHeaders: http.Header{ + "Claim": []string{"bar", "baz", "IDToken-1234"}, + }, + expectedErr: "", + }), + Entry("with a claim valued header (with preservation)", headersTableInput{ + headers: []options.Header{ + { + Name: "Claim", + PreserveRequestValue: true, + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "Claim": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{ + IDToken: "IDToken-1234", + }, + expectedHeaders: http.Header{ + "Claim": []string{"bar", "baz", "IDToken-1234"}, + }, + expectedErr: "", + }), + Entry("with a claim valued header that's not present (without preservation)", headersTableInput{ + headers: []options.Header{ + { + Name: "Claim", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "Claim": []string{"bar", "baz"}, + }, + session: nil, + expectedHeaders: http.Header{ + "Claim": []string{"bar", "baz"}, + }, + expectedErr: "", + }), + Entry("with a claim valued header that's not present (with preservation)", headersTableInput{ + headers: []options.Header{ + { + Name: "Claim", + PreserveRequestValue: true, + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "Claim": []string{"bar", "baz"}, + }, + session: nil, + expectedHeaders: http.Header{ + "Claim": []string{"bar", "baz"}, + }, + expectedErr: "", + }), + Entry("with an invalid basicAuthPassword claim valued header", headersTableInput{ + headers: []options.Header{ + { + Name: "X-Auth-Request-Authorization", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "user", + BasicAuthPassword: &options.SecretSource{ + Value: []byte(base64.StdEncoding.EncodeToString([]byte("basic-password"))), + FromEnv: "SECRET_ENV", + }, + }, + }, + }, + }, + }, + initialHeaders: http.Header{ + "foo": []string{"bar", "baz"}, + }, + session: &sessionsapi.SessionState{ + User: "user-123", + }, + expectedHeaders: nil, + expectedErr: "error building response header injector: error building response injector: error building injector for header \"X-Auth-Request-Authorization\": error loading basicAuthPassword: secret source is invalid: exactly one entry required, specify either value, fromEnv or fromFile", + }), + ) +})