1
0
mirror of https://github.com/open-telemetry/opentelemetry-go.git synced 2025-01-26 03:52:03 +02:00

Merge pull request #755 from realdave/748-prevent-panic

Ensure gRPC ClientStream override methods do not panic
This commit is contained in:
Tyler Yahn 2020-05-21 08:53:53 -07:00 committed by GitHub
commit 84a21fe9d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 13 deletions

View File

@ -133,9 +133,10 @@ const (
type clientStream struct {
grpc.ClientStream
desc *grpc.StreamDesc
events chan streamEvent
finished chan error
desc *grpc.StreamDesc
events chan streamEvent
eventsDone chan struct{}
finished chan error
receivedMessageID int
sentMessageID int
@ -147,11 +148,11 @@ func (w *clientStream) RecvMsg(m interface{}) error {
err := w.ClientStream.RecvMsg(m)
if err == nil && !w.desc.ServerStreams {
w.events <- streamEvent{receiveEndEvent, nil}
w.sendStreamEvent(receiveEndEvent, nil)
} else if err == io.EOF {
w.events <- streamEvent{receiveEndEvent, nil}
w.sendStreamEvent(receiveEndEvent, nil)
} else if err != nil {
w.events <- streamEvent{errorEvent, err}
w.sendStreamEvent(errorEvent, err)
} else {
w.receivedMessageID++
messageReceived.Event(w.Context(), w.receivedMessageID, m)
@ -167,7 +168,7 @@ func (w *clientStream) SendMsg(m interface{}) error {
messageSent.Event(w.Context(), w.sentMessageID, m)
if err != nil {
w.events <- streamEvent{errorEvent, err}
w.sendStreamEvent(errorEvent, err)
}
return err
@ -177,7 +178,7 @@ func (w *clientStream) Header() (metadata.MD, error) {
md, err := w.ClientStream.Header()
if err != nil {
w.events <- streamEvent{errorEvent, err}
w.sendStreamEvent(errorEvent, err)
}
return md, err
@ -187,9 +188,9 @@ func (w *clientStream) CloseSend() error {
err := w.ClientStream.CloseSend()
if err != nil {
w.events <- streamEvent{errorEvent, err}
w.sendStreamEvent(errorEvent, err)
} else {
w.events <- streamEvent{closeEvent, nil}
w.sendStreamEvent(closeEvent, nil)
}
return err
@ -201,10 +202,13 @@ const (
)
func wrapClientStream(s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream {
events := make(chan streamEvent, 1)
events := make(chan streamEvent)
eventsDone := make(chan struct{})
finished := make(chan error)
go func() {
defer close(eventsDone)
// Both streams have to be closed
state := byte(0)
@ -216,12 +220,12 @@ func wrapClientStream(s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream
state |= receiveEndedState
case errorEvent:
finished <- event.Err
close(events)
return
}
if state == clientClosedState|receiveEndedState {
finished <- nil
close(events)
return
}
}
}()
@ -230,10 +234,18 @@ func wrapClientStream(s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream
ClientStream: s,
desc: desc,
events: events,
eventsDone: eventsDone,
finished: finished,
}
}
func (w *clientStream) sendStreamEvent(eventType streamEventType, err error) {
select {
case <-w.eventsDone:
case w.events <- streamEvent{Type: eventType, Err: err}:
}
}
// StreamClientInterceptor returns a grpc.StreamClientInterceptor suitable
// for use in a grpc.Dial call.
//

View File

@ -376,6 +376,9 @@ func TestStreamClientInterceptor(t *testing.T) {
validate("SENT", events[i].Attributes)
validate("RECEIVED", events[i+1].Attributes)
}
// ensure CloseSend can be subsequently called
_ = streamClient.CloseSend()
}
func TestServerInterceptorError(t *testing.T) {