From f9bf364f23da5cfca858c3dc48ec752c0a065ac5 Mon Sep 17 00:00:00 2001 From: Dave McGregor Date: Wed, 20 May 2020 18:07:53 -0400 Subject: [PATCH] Ensure gRPC ClientStream override methods do not panic. Previously, the channel used to aggregate the finished state of the stream could be closed while still open to receiving stream state events. This removes the closing of the channel, and instead adds a "done" channel that is used to skip sending to the channel after the receiver is done. --- plugin/grpctrace/interceptor.go | 38 ++++++++++++++++++---------- plugin/grpctrace/interceptor_test.go | 3 +++ 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/plugin/grpctrace/interceptor.go b/plugin/grpctrace/interceptor.go index 0981954c4..6c8b6315e 100644 --- a/plugin/grpctrace/interceptor.go +++ b/plugin/grpctrace/interceptor.go @@ -133,9 +133,10 @@ const ( type clientStream struct { grpc.ClientStream - desc *grpc.StreamDesc - events chan streamEvent - finished chan error + desc *grpc.StreamDesc + events chan streamEvent + eventsDone chan struct{} + finished chan error receivedMessageID int sentMessageID int @@ -147,11 +148,11 @@ func (w *clientStream) RecvMsg(m interface{}) error { err := w.ClientStream.RecvMsg(m) if err == nil && !w.desc.ServerStreams { - w.events <- streamEvent{receiveEndEvent, nil} + w.sendStreamEvent(receiveEndEvent, nil) } else if err == io.EOF { - w.events <- streamEvent{receiveEndEvent, nil} + w.sendStreamEvent(receiveEndEvent, nil) } else if err != nil { - w.events <- streamEvent{errorEvent, err} + w.sendStreamEvent(errorEvent, err) } else { w.receivedMessageID++ messageReceived.Event(w.Context(), w.receivedMessageID, m) @@ -167,7 +168,7 @@ func (w *clientStream) SendMsg(m interface{}) error { messageSent.Event(w.Context(), w.sentMessageID, m) if err != nil { - w.events <- streamEvent{errorEvent, err} + w.sendStreamEvent(errorEvent, err) } return err @@ -177,7 +178,7 @@ func (w *clientStream) Header() (metadata.MD, error) { md, err := w.ClientStream.Header() if err != nil { - w.events <- streamEvent{errorEvent, err} + w.sendStreamEvent(errorEvent, err) } return md, err @@ -187,9 +188,9 @@ func (w *clientStream) CloseSend() error { err := w.ClientStream.CloseSend() if err != nil { - w.events <- streamEvent{errorEvent, err} + w.sendStreamEvent(errorEvent, err) } else { - w.events <- streamEvent{closeEvent, nil} + w.sendStreamEvent(closeEvent, nil) } return err @@ -201,10 +202,13 @@ const ( ) func wrapClientStream(s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream { - events := make(chan streamEvent, 1) + events := make(chan streamEvent) + eventsDone := make(chan struct{}) finished := make(chan error) go func() { + defer close(eventsDone) + // Both streams have to be closed state := byte(0) @@ -216,12 +220,12 @@ func wrapClientStream(s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream state |= receiveEndedState case errorEvent: finished <- event.Err - close(events) + return } if state == clientClosedState|receiveEndedState { finished <- nil - close(events) + return } } }() @@ -230,10 +234,18 @@ func wrapClientStream(s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream ClientStream: s, desc: desc, events: events, + eventsDone: eventsDone, finished: finished, } } +func (w *clientStream) sendStreamEvent(eventType streamEventType, err error) { + select { + case <-w.eventsDone: + case w.events <- streamEvent{Type: eventType, Err: err}: + } +} + // StreamClientInterceptor returns a grpc.StreamClientInterceptor suitable // for use in a grpc.Dial call. // diff --git a/plugin/grpctrace/interceptor_test.go b/plugin/grpctrace/interceptor_test.go index 211a9c36e..db12a3a30 100644 --- a/plugin/grpctrace/interceptor_test.go +++ b/plugin/grpctrace/interceptor_test.go @@ -376,6 +376,9 @@ func TestStreamClientInterceptor(t *testing.T) { validate("SENT", events[i].Attributes) validate("RECEIVED", events[i+1].Attributes) } + + // ensure CloseSend can be subsequently called + _ = streamClient.CloseSend() } func TestServerInterceptorError(t *testing.T) {