diff --git a/plugin/grpctrace/interceptor.go b/plugin/grpctrace/interceptor.go index 0981954c4..6c8b6315e 100644 --- a/plugin/grpctrace/interceptor.go +++ b/plugin/grpctrace/interceptor.go @@ -133,9 +133,10 @@ const ( type clientStream struct { grpc.ClientStream - desc *grpc.StreamDesc - events chan streamEvent - finished chan error + desc *grpc.StreamDesc + events chan streamEvent + eventsDone chan struct{} + finished chan error receivedMessageID int sentMessageID int @@ -147,11 +148,11 @@ func (w *clientStream) RecvMsg(m interface{}) error { err := w.ClientStream.RecvMsg(m) if err == nil && !w.desc.ServerStreams { - w.events <- streamEvent{receiveEndEvent, nil} + w.sendStreamEvent(receiveEndEvent, nil) } else if err == io.EOF { - w.events <- streamEvent{receiveEndEvent, nil} + w.sendStreamEvent(receiveEndEvent, nil) } else if err != nil { - w.events <- streamEvent{errorEvent, err} + w.sendStreamEvent(errorEvent, err) } else { w.receivedMessageID++ messageReceived.Event(w.Context(), w.receivedMessageID, m) @@ -167,7 +168,7 @@ func (w *clientStream) SendMsg(m interface{}) error { messageSent.Event(w.Context(), w.sentMessageID, m) if err != nil { - w.events <- streamEvent{errorEvent, err} + w.sendStreamEvent(errorEvent, err) } return err @@ -177,7 +178,7 @@ func (w *clientStream) Header() (metadata.MD, error) { md, err := w.ClientStream.Header() if err != nil { - w.events <- streamEvent{errorEvent, err} + w.sendStreamEvent(errorEvent, err) } return md, err @@ -187,9 +188,9 @@ func (w *clientStream) CloseSend() error { err := w.ClientStream.CloseSend() if err != nil { - w.events <- streamEvent{errorEvent, err} + w.sendStreamEvent(errorEvent, err) } else { - w.events <- streamEvent{closeEvent, nil} + w.sendStreamEvent(closeEvent, nil) } return err @@ -201,10 +202,13 @@ const ( ) func wrapClientStream(s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream { - events := make(chan streamEvent, 1) + events := make(chan streamEvent) + eventsDone := make(chan struct{}) finished := make(chan error) go func() { + defer close(eventsDone) + // Both streams have to be closed state := byte(0) @@ -216,12 +220,12 @@ func wrapClientStream(s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream state |= receiveEndedState case errorEvent: finished <- event.Err - close(events) + return } if state == clientClosedState|receiveEndedState { finished <- nil - close(events) + return } } }() @@ -230,10 +234,18 @@ func wrapClientStream(s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream ClientStream: s, desc: desc, events: events, + eventsDone: eventsDone, finished: finished, } } +func (w *clientStream) sendStreamEvent(eventType streamEventType, err error) { + select { + case <-w.eventsDone: + case w.events <- streamEvent{Type: eventType, Err: err}: + } +} + // StreamClientInterceptor returns a grpc.StreamClientInterceptor suitable // for use in a grpc.Dial call. // diff --git a/plugin/grpctrace/interceptor_test.go b/plugin/grpctrace/interceptor_test.go index 211a9c36e..db12a3a30 100644 --- a/plugin/grpctrace/interceptor_test.go +++ b/plugin/grpctrace/interceptor_test.go @@ -376,6 +376,9 @@ func TestStreamClientInterceptor(t *testing.T) { validate("SENT", events[i].Attributes) validate("RECEIVED", events[i+1].Attributes) } + + // ensure CloseSend can be subsequently called + _ = streamClient.CloseSend() } func TestServerInterceptorError(t *testing.T) {