diff --git a/plugin/othttp/filters/filters.go b/plugin/othttp/filters/filters.go new file mode 100644 index 000000000..de6e56d95 --- /dev/null +++ b/plugin/othttp/filters/filters.go @@ -0,0 +1,154 @@ +// Copyright 2020, OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package filters provides a set of filters useful with the +// othttp.WithFilter() option to control which inbound requests are traced. +package filters + +import ( + "net/http" + "strings" + + "go.opentelemetry.io/otel/plugin/othttp" +) + +// Any takes a list of Filters and returns a Filter that +// returns true if any Filter in the list returns true. +func Any(fs ...othttp.Filter) othttp.Filter { + return func(r *http.Request) bool { + for _, f := range fs { + if f(r) { + return true + } + } + return false + } +} + +// All takes a list of Filters and returns a Filter that +// returns true only if all Filters in the list return true. +func All(fs ...othttp.Filter) othttp.Filter { + return func(r *http.Request) bool { + for _, f := range fs { + if !f(r) { + return false + } + } + return true + } +} + +// None takes a list of Filters and returns a Filter that returns +// true only if none of the Filters in the list return true. +func None(fs ...othttp.Filter) othttp.Filter { + return func(r *http.Request) bool { + for _, f := range fs { + if f(r) { + return false + } + } + return true + } +} + +// Not provides a convenience mechanism for inverting a Filter +func Not(f othttp.Filter) othttp.Filter { + return func(r *http.Request) bool { + return !f(r) + } +} + +// Hostname returns a Filter that returns true if the request's +// hostname matches the provided string. +func Hostname(h string) othttp.Filter { + return func(r *http.Request) bool { + return r.URL.Hostname() == h + } +} + +// Path returns a Filter that returns true if the request's +// path matches the provided string. +func Path(p string) othttp.Filter { + return func(r *http.Request) bool { + return r.URL.Path == p + } +} + +// PathPrefix returns a Filter that returns true if the request's +// path starts with the provided string. +func PathPrefix(p string) othttp.Filter { + return func(r *http.Request) bool { + return strings.HasPrefix(r.URL.Path, p) + } +} + +// Query returns a Filter that returns true if the request +// includes a query parameter k with a value equal to v. +func Query(k, v string) othttp.Filter { + return func(r *http.Request) bool { + for _, qv := range r.URL.Query()[k] { + if v == qv { + return true + } + } + return false + } +} + +// QueryContains returns a Filter that returns true if the request +// includes a query parameter k with a value that contains v. +func QueryContains(k, v string) othttp.Filter { + return func(r *http.Request) bool { + for _, qv := range r.URL.Query()[k] { + if strings.Contains(qv, v) { + return true + } + } + return false + } +} + +// Method returns a Filter that returns true if the request +// method is equal to the provided value. +func Method(m string) othttp.Filter { + return func(r *http.Request) bool { + return m == r.Method + } +} + +// Header returns a Filter that returns true if the request +// includes a header k with a value equal to v. +func Header(k, v string) othttp.Filter { + return func(r *http.Request) bool { + for _, hv := range r.Header.Values(k) { + if v == hv { + return true + } + } + return false + } +} + +// HeaderContains returns a Filter that returns true if the request +// includes a header k with a value that contains v. +func HeaderContains(k, v string) othttp.Filter { + return func(r *http.Request) bool { + for _, hv := range r.Header.Values(k) { + if strings.Contains(hv, v) { + return true + } + } + return false + } +} diff --git a/plugin/othttp/filters/filters_test.go b/plugin/othttp/filters/filters_test.go new file mode 100644 index 000000000..d10dfd1b5 --- /dev/null +++ b/plugin/othttp/filters/filters_test.go @@ -0,0 +1,266 @@ +// Copyright 2020, OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package filters + +import ( + "net/http" + "net/url" + "testing" + + "go.opentelemetry.io/otel/plugin/othttp" +) + +type scenario struct { + name string + filter othttp.Filter + req *http.Request + exp bool +} + +func TestAny(t *testing.T) { + for _, s := range []scenario{ + { + name: "no matching filters", + filter: Any(Path("/foo"), Hostname("bar.baz")), + req: &http.Request{URL: &url.URL{Path: "/boo", Host: "baz.bar:8080"}}, + exp: false, + }, + { + name: "one matching filter", + filter: Any(Path("/foo"), Hostname("bar.baz")), + req: &http.Request{URL: &url.URL{Path: "/foo", Host: "baz.bar:8080"}}, + exp: true, + }, + { + name: "all matching filters", + filter: Any(Path("/foo"), Hostname("bar.baz")), + req: &http.Request{URL: &url.URL{Path: "/foo", Host: "bar.baz:8080"}}, + exp: true, + }, + } { + res := s.filter(s.req) + if s.exp != res { + t.Errorf("Failed testing %q. Expected %t, got %t", s.name, s.exp, res) + } + } +} + +func TestAll(t *testing.T) { + for _, s := range []scenario{ + { + name: "no matching filters", + filter: All(Path("/foo"), Hostname("bar.baz")), + req: &http.Request{URL: &url.URL{Path: "/boo", Host: "baz.bar:8080"}}, + exp: false, + }, + { + name: "one matching filter", + filter: All(Path("/foo"), Hostname("bar.baz")), + req: &http.Request{URL: &url.URL{Path: "/foo", Host: "baz.bar:8080"}}, + exp: false, + }, + { + name: "all matching filters", + filter: All(Path("/foo"), Hostname("bar.baz")), + req: &http.Request{URL: &url.URL{Path: "/foo", Host: "bar.baz:8080"}}, + exp: true, + }, + } { + res := s.filter(s.req) + if s.exp != res { + t.Errorf("Failed testing %q. Expected %t, got %t", s.name, s.exp, res) + } + } +} + +func TestNone(t *testing.T) { + for _, s := range []scenario{ + { + name: "no matching filters", + filter: None(Path("/foo"), Hostname("bar.baz")), + req: &http.Request{URL: &url.URL{Path: "/boo", Host: "baz.bar:8080"}}, + exp: true, + }, + { + name: "one matching filter", + filter: None(Path("/foo"), Hostname("bar.baz")), + req: &http.Request{URL: &url.URL{Path: "/foo", Host: "baz.bar:8080"}}, + exp: false, + }, + { + name: "all matching filters", + filter: None(Path("/foo"), Hostname("bar.baz")), + req: &http.Request{URL: &url.URL{Path: "/foo", Host: "bar.baz:8080"}}, + exp: false, + }, + } { + res := s.filter(s.req) + if s.exp != res { + t.Errorf("Failed testing %q. Expected %t, got %t", s.name, s.exp, res) + } + } +} + +func TestNot(t *testing.T) { + req := &http.Request{URL: &url.URL{Path: "/foo", Host: "bar.baz:8080"}} + filter := Path("/foo") + if filter(req) == Not(filter)(req) { + t.Error("Not filter should invert the result of the supplied filter") + } +} + +func TestPathPrefix(t *testing.T) { + for _, s := range []scenario{ + { + name: "non-matching prefix", + filter: PathPrefix("/foo"), + req: &http.Request{URL: &url.URL{Path: "/boo/far", Host: "baz.bar:8080"}}, + exp: false, + }, + { + name: "matching prefix", + filter: PathPrefix("/foo"), + req: &http.Request{URL: &url.URL{Path: "/foo/bar", Host: "bar.baz:8080"}}, + exp: true, + }, + } { + res := s.filter(s.req) + if s.exp != res { + t.Errorf("Failed testing %q. Expected %t, got %t", s.name, s.exp, res) + } + } +} + +func TestMethod(t *testing.T) { + for _, s := range []scenario{ + { + name: "non-matching method", + filter: Method(http.MethodGet), + req: &http.Request{Method: http.MethodHead, URL: &url.URL{Path: "/boo/far", Host: "baz.bar:8080"}}, + exp: false, + }, + { + name: "matching method", + filter: Method(http.MethodGet), + req: &http.Request{Method: http.MethodGet, URL: &url.URL{Path: "/boo/far", Host: "baz.bar:8080"}}, + exp: true, + }, + } { + res := s.filter(s.req) + if s.exp != res { + t.Errorf("Failed testing %q. Expected %t, got %t", s.name, s.exp, res) + } + } +} + +func TestQuery(t *testing.T) { + matching, _ := url.Parse("http://bar.baz:8080/foo/bar?key=value") + nonMatching, _ := url.Parse("http://bar.baz:8080/foo/bar?key=other") + for _, s := range []scenario{ + { + name: "non-matching query parameter", + filter: Query("key", "value"), + req: &http.Request{Method: http.MethodHead, URL: nonMatching}, + exp: false, + }, + { + name: "matching query parameter", + filter: Query("key", "value"), + req: &http.Request{Method: http.MethodGet, URL: matching}, + exp: true, + }, + } { + res := s.filter(s.req) + if s.exp != res { + t.Errorf("Failed testing %q. Expected %t, got %t", s.name, s.exp, res) + } + } +} + +func TestQueryContains(t *testing.T) { + matching, _ := url.Parse("http://bar.baz:8080/foo/bar?key=value") + nonMatching, _ := url.Parse("http://bar.baz:8080/foo/bar?key=other") + for _, s := range []scenario{ + { + name: "non-matching query parameter", + filter: QueryContains("key", "alu"), + req: &http.Request{Method: http.MethodHead, URL: nonMatching}, + exp: false, + }, + { + name: "matching query parameter", + filter: QueryContains("key", "alu"), + req: &http.Request{Method: http.MethodGet, URL: matching}, + exp: true, + }, + } { + res := s.filter(s.req) + if s.exp != res { + t.Errorf("Failed testing %q. Expected %t, got %t", s.name, s.exp, res) + } + } +} + +func TestHeader(t *testing.T) { + matching := http.Header{} + matching.Add("key", "value") + nonMatching := http.Header{} + nonMatching.Add("key", "other") + for _, s := range []scenario{ + { + name: "non-matching query parameter", + filter: Header("key", "value"), + req: &http.Request{Method: http.MethodHead, Header: nonMatching}, + exp: false, + }, + { + name: "matching query parameter", + filter: Header("key", "value"), + req: &http.Request{Method: http.MethodGet, Header: matching}, + exp: true, + }, + } { + res := s.filter(s.req) + if s.exp != res { + t.Errorf("Failed testing %q. Expected %t, got %t", s.name, s.exp, res) + } + } +} + +func TestHeaderContains(t *testing.T) { + matching := http.Header{} + matching.Add("key", "value") + nonMatching := http.Header{} + nonMatching.Add("key", "other") + for _, s := range []scenario{ + { + name: "non-matching query parameter", + filter: HeaderContains("key", "alu"), + req: &http.Request{Method: http.MethodHead, Header: nonMatching}, + exp: false, + }, + { + name: "matching query parameter", + filter: HeaderContains("key", "alu"), + req: &http.Request{Method: http.MethodGet, Header: matching}, + exp: true, + }, + } { + res := s.filter(s.req) + if s.exp != res { + t.Errorf("Failed testing %q. Expected %t, got %t", s.name, s.exp, res) + } + } +} diff --git a/plugin/othttp/handler.go b/plugin/othttp/handler.go index 01d707540..3ff192d5a 100644 --- a/plugin/othttp/handler.go +++ b/plugin/othttp/handler.go @@ -41,6 +41,10 @@ const ( WriteErrorKey = core.Key("http.write_error") // if an error occurred while writing a reply, the string of the error (io.EOF is not recorded) ) +// Filter is a predicate used to determine whether a given http.request should +// be traced. A Filter must return true if the request should be traced. +type Filter func(*http.Request) bool + // Handler is http middleware that corresponds to the http.Handler interface and // is designed to wrap a http.Mux (or equivalent), while individual routes on // the mux are wrapped with WithRouteTag. A Handler will add various attributes @@ -54,6 +58,7 @@ type Handler struct { spanStartOptions []trace.StartOption readEvent bool writeEvent bool + filters []Filter } // Option function used for setting *optional* Handler properties @@ -93,6 +98,18 @@ func WithSpanOptions(opts ...trace.StartOption) Option { } } +// WithFilter adds a filter to the list of filters used by the handler. +// If any filter indicates to exclude a request then the request will not be +// traced. All filters must allow a request to be traced for a Span to be created. +// If no filters are provided then all requests are traced. +// Filters will be invoked for each processed request, it is advised to make them +// simple and fast. +func WithFilter(f Filter) Option { + return func(h *Handler) { + h.filters = append(h.filters, f) + } +} + type event int // Different types of events that can be recorded, see WithMessageEvents @@ -141,6 +158,14 @@ func NewHandler(handler http.Handler, operation string, opts ...Option) http.Han // ServeHTTP serves HTTP requests (http.Handler) func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + for _, f := range h.filters { + if !f(r) { + // Simply pass through to the handler if a filter rejects the request + h.handler.ServeHTTP(w, r) + return + } + } + opts := append([]trace.StartOption{}, h.spanStartOptions...) // start with the configured options ctx := propagation.ExtractHTTP(r.Context(), h.props, r.Header) diff --git a/plugin/othttp/handler_test.go b/plugin/othttp/handler_test.go index 59e0858b4..fa3229064 100644 --- a/plugin/othttp/handler_test.go +++ b/plugin/othttp/handler_test.go @@ -59,3 +59,44 @@ func TestBasics(t *testing.T) { t.Fatalf("got %q, expected %q", got, expected) } } + +func TestBasicFilter(t *testing.T) { + rr := httptest.NewRecorder() + + var id uint64 + tracer := mocktrace.MockTracer{StartSpanID: &id} + + h := NewHandler( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, err := io.WriteString(w, "hello world"); err != nil { + t.Fatal(err) + } + }), "test_handler", + WithTracer(&tracer), + WithFilter(func(r *http.Request) bool { + return false + }), + ) + + r, err := http.NewRequest(http.MethodGet, "http://localhost/", nil) + if err != nil { + t.Fatal(err) + } + h.ServeHTTP(rr, r) + if got, expected := rr.Result().StatusCode, http.StatusOK; got != expected { + t.Fatalf("got %d, expected %d", got, expected) + } + if got := rr.Header().Get("Traceparent"); got != "" { + t.Fatal("expected empty trace header") + } + if got, expected := id, uint64(0); got != expected { + t.Fatalf("got %d, expected %d", got, expected) + } + d, err := ioutil.ReadAll(rr.Result().Body) + if err != nil { + t.Fatal(err) + } + if got, expected := string(d), "hello world"; got != expected { + t.Fatalf("got %q, expected %q", got, expected) + } +}