From f9bf364f23da5cfca858c3dc48ec752c0a065ac5 Mon Sep 17 00:00:00 2001
From: Dave McGregor <dave.s.mcgregor@gmail.com>
Date: Wed, 20 May 2020 18:07:53 -0400
Subject: [PATCH] Ensure gRPC ClientStream override methods do not panic.

Previously, the channel used to aggregate the finished state of the stream
could be closed while still open to receiving stream state events.

This removes the closing of the channel, and instead adds a "done" channel that
is used to skip sending to the channel after the receiver is done.
---
 plugin/grpctrace/interceptor.go      | 38 ++++++++++++++++++----------
 plugin/grpctrace/interceptor_test.go |  3 +++
 2 files changed, 28 insertions(+), 13 deletions(-)

diff --git a/plugin/grpctrace/interceptor.go b/plugin/grpctrace/interceptor.go
index 0981954c4..6c8b6315e 100644
--- a/plugin/grpctrace/interceptor.go
+++ b/plugin/grpctrace/interceptor.go
@@ -133,9 +133,10 @@ const (
 type clientStream struct {
 	grpc.ClientStream
 
-	desc     *grpc.StreamDesc
-	events   chan streamEvent
-	finished chan error
+	desc       *grpc.StreamDesc
+	events     chan streamEvent
+	eventsDone chan struct{}
+	finished   chan error
 
 	receivedMessageID int
 	sentMessageID     int
@@ -147,11 +148,11 @@ func (w *clientStream) RecvMsg(m interface{}) error {
 	err := w.ClientStream.RecvMsg(m)
 
 	if err == nil && !w.desc.ServerStreams {
-		w.events <- streamEvent{receiveEndEvent, nil}
+		w.sendStreamEvent(receiveEndEvent, nil)
 	} else if err == io.EOF {
-		w.events <- streamEvent{receiveEndEvent, nil}
+		w.sendStreamEvent(receiveEndEvent, nil)
 	} else if err != nil {
-		w.events <- streamEvent{errorEvent, err}
+		w.sendStreamEvent(errorEvent, err)
 	} else {
 		w.receivedMessageID++
 		messageReceived.Event(w.Context(), w.receivedMessageID, m)
@@ -167,7 +168,7 @@ func (w *clientStream) SendMsg(m interface{}) error {
 	messageSent.Event(w.Context(), w.sentMessageID, m)
 
 	if err != nil {
-		w.events <- streamEvent{errorEvent, err}
+		w.sendStreamEvent(errorEvent, err)
 	}
 
 	return err
@@ -177,7 +178,7 @@ func (w *clientStream) Header() (metadata.MD, error) {
 	md, err := w.ClientStream.Header()
 
 	if err != nil {
-		w.events <- streamEvent{errorEvent, err}
+		w.sendStreamEvent(errorEvent, err)
 	}
 
 	return md, err
@@ -187,9 +188,9 @@ func (w *clientStream) CloseSend() error {
 	err := w.ClientStream.CloseSend()
 
 	if err != nil {
-		w.events <- streamEvent{errorEvent, err}
+		w.sendStreamEvent(errorEvent, err)
 	} else {
-		w.events <- streamEvent{closeEvent, nil}
+		w.sendStreamEvent(closeEvent, nil)
 	}
 
 	return err
@@ -201,10 +202,13 @@ const (
 )
 
 func wrapClientStream(s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream {
-	events := make(chan streamEvent, 1)
+	events := make(chan streamEvent)
+	eventsDone := make(chan struct{})
 	finished := make(chan error)
 
 	go func() {
+		defer close(eventsDone)
+
 		// Both streams have to be closed
 		state := byte(0)
 
@@ -216,12 +220,12 @@ func wrapClientStream(s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream
 				state |= receiveEndedState
 			case errorEvent:
 				finished <- event.Err
-				close(events)
+				return
 			}
 
 			if state == clientClosedState|receiveEndedState {
 				finished <- nil
-				close(events)
+				return
 			}
 		}
 	}()
@@ -230,10 +234,18 @@ func wrapClientStream(s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream
 		ClientStream: s,
 		desc:         desc,
 		events:       events,
+		eventsDone:   eventsDone,
 		finished:     finished,
 	}
 }
 
+func (w *clientStream) sendStreamEvent(eventType streamEventType, err error) {
+	select {
+	case <-w.eventsDone:
+	case w.events <- streamEvent{Type: eventType, Err: err}:
+	}
+}
+
 // StreamClientInterceptor returns a grpc.StreamClientInterceptor suitable
 // for use in a grpc.Dial call.
 //
diff --git a/plugin/grpctrace/interceptor_test.go b/plugin/grpctrace/interceptor_test.go
index 211a9c36e..db12a3a30 100644
--- a/plugin/grpctrace/interceptor_test.go
+++ b/plugin/grpctrace/interceptor_test.go
@@ -376,6 +376,9 @@ func TestStreamClientInterceptor(t *testing.T) {
 		validate("SENT", events[i].Attributes)
 		validate("RECEIVED", events[i+1].Attributes)
 	}
+
+	// ensure CloseSend can be subsequently called
+	_ = streamClient.CloseSend()
 }
 
 func TestServerInterceptorError(t *testing.T) {