mirror of
https://github.com/open-telemetry/opentelemetry-go.git
synced 2025-01-26 03:52:03 +02:00
Merge pull request #755 from realdave/748-prevent-panic
Ensure gRPC ClientStream override methods do not panic
This commit is contained in:
commit
84a21fe9d7
@ -133,9 +133,10 @@ const (
|
||||
type clientStream struct {
|
||||
grpc.ClientStream
|
||||
|
||||
desc *grpc.StreamDesc
|
||||
events chan streamEvent
|
||||
finished chan error
|
||||
desc *grpc.StreamDesc
|
||||
events chan streamEvent
|
||||
eventsDone chan struct{}
|
||||
finished chan error
|
||||
|
||||
receivedMessageID int
|
||||
sentMessageID int
|
||||
@ -147,11 +148,11 @@ func (w *clientStream) RecvMsg(m interface{}) error {
|
||||
err := w.ClientStream.RecvMsg(m)
|
||||
|
||||
if err == nil && !w.desc.ServerStreams {
|
||||
w.events <- streamEvent{receiveEndEvent, nil}
|
||||
w.sendStreamEvent(receiveEndEvent, nil)
|
||||
} else if err == io.EOF {
|
||||
w.events <- streamEvent{receiveEndEvent, nil}
|
||||
w.sendStreamEvent(receiveEndEvent, nil)
|
||||
} else if err != nil {
|
||||
w.events <- streamEvent{errorEvent, err}
|
||||
w.sendStreamEvent(errorEvent, err)
|
||||
} else {
|
||||
w.receivedMessageID++
|
||||
messageReceived.Event(w.Context(), w.receivedMessageID, m)
|
||||
@ -167,7 +168,7 @@ func (w *clientStream) SendMsg(m interface{}) error {
|
||||
messageSent.Event(w.Context(), w.sentMessageID, m)
|
||||
|
||||
if err != nil {
|
||||
w.events <- streamEvent{errorEvent, err}
|
||||
w.sendStreamEvent(errorEvent, err)
|
||||
}
|
||||
|
||||
return err
|
||||
@ -177,7 +178,7 @@ func (w *clientStream) Header() (metadata.MD, error) {
|
||||
md, err := w.ClientStream.Header()
|
||||
|
||||
if err != nil {
|
||||
w.events <- streamEvent{errorEvent, err}
|
||||
w.sendStreamEvent(errorEvent, err)
|
||||
}
|
||||
|
||||
return md, err
|
||||
@ -187,9 +188,9 @@ func (w *clientStream) CloseSend() error {
|
||||
err := w.ClientStream.CloseSend()
|
||||
|
||||
if err != nil {
|
||||
w.events <- streamEvent{errorEvent, err}
|
||||
w.sendStreamEvent(errorEvent, err)
|
||||
} else {
|
||||
w.events <- streamEvent{closeEvent, nil}
|
||||
w.sendStreamEvent(closeEvent, nil)
|
||||
}
|
||||
|
||||
return err
|
||||
@ -201,10 +202,13 @@ const (
|
||||
)
|
||||
|
||||
func wrapClientStream(s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream {
|
||||
events := make(chan streamEvent, 1)
|
||||
events := make(chan streamEvent)
|
||||
eventsDone := make(chan struct{})
|
||||
finished := make(chan error)
|
||||
|
||||
go func() {
|
||||
defer close(eventsDone)
|
||||
|
||||
// Both streams have to be closed
|
||||
state := byte(0)
|
||||
|
||||
@ -216,12 +220,12 @@ func wrapClientStream(s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream
|
||||
state |= receiveEndedState
|
||||
case errorEvent:
|
||||
finished <- event.Err
|
||||
close(events)
|
||||
return
|
||||
}
|
||||
|
||||
if state == clientClosedState|receiveEndedState {
|
||||
finished <- nil
|
||||
close(events)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
@ -230,10 +234,18 @@ func wrapClientStream(s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream
|
||||
ClientStream: s,
|
||||
desc: desc,
|
||||
events: events,
|
||||
eventsDone: eventsDone,
|
||||
finished: finished,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *clientStream) sendStreamEvent(eventType streamEventType, err error) {
|
||||
select {
|
||||
case <-w.eventsDone:
|
||||
case w.events <- streamEvent{Type: eventType, Err: err}:
|
||||
}
|
||||
}
|
||||
|
||||
// StreamClientInterceptor returns a grpc.StreamClientInterceptor suitable
|
||||
// for use in a grpc.Dial call.
|
||||
//
|
||||
|
@ -376,6 +376,9 @@ func TestStreamClientInterceptor(t *testing.T) {
|
||||
validate("SENT", events[i].Attributes)
|
||||
validate("RECEIVED", events[i+1].Attributes)
|
||||
}
|
||||
|
||||
// ensure CloseSend can be subsequently called
|
||||
_ = streamClient.CloseSend()
|
||||
}
|
||||
|
||||
func TestServerInterceptorError(t *testing.T) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user