1
0
mirror of https://github.com/open-telemetry/opentelemetry-go.git synced 2026-05-22 09:35:21 +02:00

Ensure gRPC ClientStream override methods do not panic.

Previously, the channel used to aggregate the finished state of the stream
could be closed while still open to receiving stream state events.

This removes the closing of the channel, and instead adds a "done" channel that
is used to skip sending to the channel after the receiver is done.
This commit is contained in:
Dave McGregor
2020-05-20 18:07:53 -04:00
parent 5461669733
commit f9bf364f23
2 changed files with 28 additions and 13 deletions
+22 -10
View File
@@ -135,6 +135,7 @@ type clientStream struct {
desc *grpc.StreamDesc desc *grpc.StreamDesc
events chan streamEvent events chan streamEvent
eventsDone chan struct{}
finished chan error finished chan error
receivedMessageID int receivedMessageID int
@@ -147,11 +148,11 @@ func (w *clientStream) RecvMsg(m interface{}) error {
err := w.ClientStream.RecvMsg(m) err := w.ClientStream.RecvMsg(m)
if err == nil && !w.desc.ServerStreams { if err == nil && !w.desc.ServerStreams {
w.events <- streamEvent{receiveEndEvent, nil} w.sendStreamEvent(receiveEndEvent, nil)
} else if err == io.EOF { } else if err == io.EOF {
w.events <- streamEvent{receiveEndEvent, nil} w.sendStreamEvent(receiveEndEvent, nil)
} else if err != nil { } else if err != nil {
w.events <- streamEvent{errorEvent, err} w.sendStreamEvent(errorEvent, err)
} else { } else {
w.receivedMessageID++ w.receivedMessageID++
messageReceived.Event(w.Context(), w.receivedMessageID, m) messageReceived.Event(w.Context(), w.receivedMessageID, m)
@@ -167,7 +168,7 @@ func (w *clientStream) SendMsg(m interface{}) error {
messageSent.Event(w.Context(), w.sentMessageID, m) messageSent.Event(w.Context(), w.sentMessageID, m)
if err != nil { if err != nil {
w.events <- streamEvent{errorEvent, err} w.sendStreamEvent(errorEvent, err)
} }
return err return err
@@ -177,7 +178,7 @@ func (w *clientStream) Header() (metadata.MD, error) {
md, err := w.ClientStream.Header() md, err := w.ClientStream.Header()
if err != nil { if err != nil {
w.events <- streamEvent{errorEvent, err} w.sendStreamEvent(errorEvent, err)
} }
return md, err return md, err
@@ -187,9 +188,9 @@ func (w *clientStream) CloseSend() error {
err := w.ClientStream.CloseSend() err := w.ClientStream.CloseSend()
if err != nil { if err != nil {
w.events <- streamEvent{errorEvent, err} w.sendStreamEvent(errorEvent, err)
} else { } else {
w.events <- streamEvent{closeEvent, nil} w.sendStreamEvent(closeEvent, nil)
} }
return err return err
@@ -201,10 +202,13 @@ const (
) )
func wrapClientStream(s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream { func wrapClientStream(s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream {
events := make(chan streamEvent, 1) events := make(chan streamEvent)
eventsDone := make(chan struct{})
finished := make(chan error) finished := make(chan error)
go func() { go func() {
defer close(eventsDone)
// Both streams have to be closed // Both streams have to be closed
state := byte(0) state := byte(0)
@@ -216,12 +220,12 @@ func wrapClientStream(s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream
state |= receiveEndedState state |= receiveEndedState
case errorEvent: case errorEvent:
finished <- event.Err finished <- event.Err
close(events) return
} }
if state == clientClosedState|receiveEndedState { if state == clientClosedState|receiveEndedState {
finished <- nil finished <- nil
close(events) return
} }
} }
}() }()
@@ -230,10 +234,18 @@ func wrapClientStream(s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream
ClientStream: s, ClientStream: s,
desc: desc, desc: desc,
events: events, events: events,
eventsDone: eventsDone,
finished: finished, finished: finished,
} }
} }
func (w *clientStream) sendStreamEvent(eventType streamEventType, err error) {
select {
case <-w.eventsDone:
case w.events <- streamEvent{Type: eventType, Err: err}:
}
}
// StreamClientInterceptor returns a grpc.StreamClientInterceptor suitable // StreamClientInterceptor returns a grpc.StreamClientInterceptor suitable
// for use in a grpc.Dial call. // for use in a grpc.Dial call.
// //
+3
View File
@@ -376,6 +376,9 @@ func TestStreamClientInterceptor(t *testing.T) {
validate("SENT", events[i].Attributes) validate("SENT", events[i].Attributes)
validate("RECEIVED", events[i+1].Attributes) validate("RECEIVED", events[i+1].Attributes)
} }
// ensure CloseSend can be subsequently called
_ = streamClient.CloseSend()
} }
func TestServerInterceptorError(t *testing.T) { func TestServerInterceptorError(t *testing.T) {