// Copyright The OpenTelemetry Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package grpctrace import ( "context" "sync" "testing" "time" "google.golang.org/grpc" "google.golang.org/grpc/metadata" "go.opentelemetry.io/otel/api/core" export "go.opentelemetry.io/otel/sdk/export/trace" sdktrace "go.opentelemetry.io/otel/sdk/trace" ) type testExporter struct { mu sync.Mutex spanMap map[string]*export.SpanData } func (t *testExporter) ExportSpan(ctx context.Context, s *export.SpanData) { t.mu.Lock() defer t.mu.Unlock() t.spanMap[s.Name] = s } type mockUICInvoker struct { ctx context.Context } func (mcuici *mockUICInvoker) invoker(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { mcuici.ctx = ctx return nil } type mockProtoMessage struct{} func (mm *mockProtoMessage) Reset() { } func (mm *mockProtoMessage) String() string { return "mock" } func (mm *mockProtoMessage) ProtoMessage() { } func TestUnaryClientInterceptor(t *testing.T) { exp := &testExporter{spanMap: make(map[string]*export.SpanData)} tp, _ := sdktrace.NewProvider(sdktrace.WithSyncer(exp), sdktrace.WithConfig(sdktrace.Config{ DefaultSampler: sdktrace.AlwaysSample(), }, )) clientConn, err := grpc.Dial("fake:connection", grpc.WithInsecure()) if err != nil { t.Fatalf("failed to create client connection: %v", err) } tracer := tp.Tracer("grpctrace/client") unaryInterceptor := UnaryClientInterceptor(tracer) req := &mockProtoMessage{} reply := &mockProtoMessage{} uniInterceptorInvoker := &mockUICInvoker{} checks := []struct { name string expectedAttr map[core.Key]core.Value eventsAttr []map[core.Key]core.Value }{ { name: "/github.com.serviceName/bar", expectedAttr: map[core.Key]core.Value{ rpcServiceKey: core.String("serviceName"), netPeerIPKey: core.String("fake"), netPeerPortKey: core.String("connection"), }, eventsAttr: []map[core.Key]core.Value{ { messageTypeKey: core.String("SENT"), messageIDKey: core.Int(1), }, { messageTypeKey: core.String("RECEIVED"), messageIDKey: core.Int(1), }, }, }, { name: "/serviceName/bar", expectedAttr: map[core.Key]core.Value{ rpcServiceKey: core.String("serviceName"), }, eventsAttr: []map[core.Key]core.Value{ { messageTypeKey: core.String("SENT"), messageIDKey: core.Int(1), }, { messageTypeKey: core.String("RECEIVED"), messageIDKey: core.Int(1), }, }, }, { name: "serviceName/bar", expectedAttr: map[core.Key]core.Value{rpcServiceKey: core.String("serviceName")}, }, { name: "invalidName", expectedAttr: map[core.Key]core.Value{rpcServiceKey: core.String("")}, }, { name: "/github.com.foo.serviceName_123/method", expectedAttr: map[core.Key]core.Value{rpcServiceKey: core.String("serviceName_123")}, }, } for _, check := range checks { err = unaryInterceptor(context.Background(), check.name, req, reply, clientConn, uniInterceptorInvoker.invoker) if err != nil { t.Errorf("failed to run unary interceptor: %v", err) continue } spanData, ok := exp.spanMap[check.name] if !ok { t.Errorf("no span data found for name < %s >", check.name) continue } attrs := spanData.Attributes if len(check.expectedAttr) > len(attrs) { t.Errorf("attributes received are less than expected attributes, received %d, expected %d", len(attrs), len(check.expectedAttr)) } for _, attr := range attrs { expectedAttr, ok := check.expectedAttr[attr.Key] if ok { if expectedAttr != attr.Value { t.Errorf("name: %s invalid %s found. expected %s, actual %s", check.name, string(attr.Key), expectedAttr.AsString(), attr.Value.AsString()) } delete(check.expectedAttr, attr.Key) } } // Check if any expected attr not seen if len(check.expectedAttr) > 0 { for attr := range check.expectedAttr { t.Errorf("missing attribute %s in span", string(attr)) } } events := spanData.MessageEvents if len(check.eventsAttr) > len(events) { t.Errorf("events received are less than expected events, received %d, expected %d", len(events), len(check.eventsAttr)) } for event := 0; event < len(check.eventsAttr); event++ { for _, attr := range events[event].Attributes { expectedAttr, ok := check.eventsAttr[event][attr.Key] if ok { if attr.Value != expectedAttr { t.Errorf("invalid value for attribute %s in events, expected %s actual %s", string(attr.Key), attr.Value.AsString(), expectedAttr.AsString()) } delete(check.eventsAttr[event], attr.Key) } } if len(check.eventsAttr[event]) > 0 { for attr := range check.eventsAttr[event] { t.Errorf("missing attribute %s in span event", string(attr)) } } } } } type mockClientStream struct { Desc *grpc.StreamDesc Ctx context.Context } func (mockClientStream) SendMsg(m interface{}) error { return nil } func (mockClientStream) RecvMsg(m interface{}) error { return nil } func (mockClientStream) CloseSend() error { return nil } func (c mockClientStream) Context() context.Context { return c.Ctx } func (mockClientStream) Header() (metadata.MD, error) { return nil, nil } func (mockClientStream) Trailer() metadata.MD { return nil } func TestStreamClientInterceptor(t *testing.T) { exp := &testExporter{spanMap: make(map[string]*export.SpanData)} tp, _ := sdktrace.NewProvider(sdktrace.WithSyncer(exp), sdktrace.WithConfig(sdktrace.Config{ DefaultSampler: sdktrace.AlwaysSample(), }, )) clientConn, err := grpc.Dial("fake:connection", grpc.WithInsecure()) if err != nil { t.Fatalf("failed to create client connection: %v", err) } // tracer tracer := tp.Tracer("grpctrace/Server") streamCI := StreamClientInterceptor(tracer) var mockClStr mockClientStream methodName := "/github.com.serviceName/bar" streamClient, err := streamCI(context.Background(), &grpc.StreamDesc{ServerStreams: true}, clientConn, methodName, func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { mockClStr = mockClientStream{Desc: desc, Ctx: ctx} return mockClStr, nil }) if err != nil { t.Fatalf("failed to initialize grpc stream client: %v", err) } // no span exported while stream is open if _, ok := exp.spanMap[methodName]; ok { t.Fatalf("span shouldn't end while stream is open") } req := &mockProtoMessage{} reply := &mockProtoMessage{} // send and receive fake data for i := 0; i < 10; i++ { _ = streamClient.SendMsg(req) _ = streamClient.RecvMsg(reply) } // close client and server stream _ = streamClient.CloseSend() mockClStr.Desc.ServerStreams = false _ = streamClient.RecvMsg(reply) // added retry because span end is called in separate go routine var spanData *export.SpanData for retry := 0; retry < 5; retry++ { ok := false exp.mu.Lock() spanData, ok = exp.spanMap[methodName] exp.mu.Unlock() if ok { break } time.Sleep(time.Second * 1) } if spanData == nil { t.Fatalf("no span data found for name < %s >", methodName) } attrs := spanData.Attributes expectedAttr := map[core.Key]string{ rpcServiceKey: "serviceName", netPeerIPKey: "fake", netPeerPortKey: "connection", } for _, attr := range attrs { expected, ok := expectedAttr[attr.Key] if ok { if expected != attr.Value.AsString() { t.Errorf("name: %s invalid %s found. expected %s, actual %s", methodName, string(attr.Key), expected, attr.Value.AsString()) } } } events := spanData.MessageEvents if len(events) != 20 { t.Fatalf("incorrect number of events expected 20 got %d", len(events)) } for i := 0; i < 20; i += 2 { msgID := i/2 + 1 validate := func(eventName string, attrs []core.KeyValue) { for _, attr := range attrs { if attr.Key == messageTypeKey && attr.Value.AsString() != eventName { t.Errorf("invalid event on index: %d expecting %s event, receive %s event", i, eventName, attr.Value.AsString()) } if attr.Key == messageIDKey && attr.Value != core.Int(msgID) { t.Errorf("invalid id for message event expected %d received %d", msgID, attr.Value.AsInt32()) } } } validate("SENT", events[i].Attributes) validate("RECEIVED", events[i+1].Attributes) } }