diff --git a/client/rpc_stream.go b/client/rpc_stream.go index f904d2dd..abc4f5f1 100644 --- a/client/rpc_stream.go +++ b/client/rpc_stream.go @@ -83,7 +83,10 @@ func (r *rpcStream) Recv(msg interface{}) error { var resp codec.Message - if err := r.codec.ReadHeader(&resp, codec.Response); err != nil { + r.Unlock() + err := r.codec.ReadHeader(&resp, codec.Response) + r.Lock() + if err != nil { if err == io.EOF && !r.isClosed() { r.err = io.ErrUnexpectedEOF return io.ErrUnexpectedEOF @@ -102,11 +105,17 @@ func (r *rpcStream) Recv(msg interface{}) error { } else { r.err = io.EOF } - if err := r.codec.ReadBody(nil); err != nil { + r.Unlock() + err = r.codec.ReadBody(nil) + r.Lock() + if err != nil { r.err = err } default: - if err := r.codec.ReadBody(msg); err != nil { + r.Unlock() + err = r.codec.ReadBody(msg) + r.Lock() + if err != nil { r.err = err } } @@ -121,11 +130,15 @@ func (r *rpcStream) Error() error { } func (r *rpcStream) Close() error { + r.RLock() + select { case <-r.closed: + r.RUnlock() return nil default: close(r.closed) + r.RUnlock() // send the end of stream message if r.sendEOS { diff --git a/server/rpc_stream.go b/server/rpc_stream.go index a4e64af8..7421fd45 100644 --- a/server/rpc_stream.go +++ b/server/rpc_stream.go @@ -48,13 +48,13 @@ func (r *rpcStream) Send(msg interface{}) error { } func (r *rpcStream) Recv(msg interface{}) error { - r.Lock() - defer r.Unlock() - req := new(codec.Message) req.Type = codec.Request - if err := r.codec.ReadHeader(req, req.Type); err != nil { + err := r.codec.ReadHeader(req, req.Type) + r.Lock() + defer r.Unlock() + if err != nil { // discard body r.codec.ReadBody(nil) r.err = err @@ -67,7 +67,9 @@ func (r *rpcStream) Recv(msg interface{}) error { switch req.Error { case lastStreamResponseError.Error(): // discard body + r.Unlock() r.codec.ReadBody(nil) + r.Lock() r.err = io.EOF return io.EOF default: @@ -77,7 +79,10 @@ func (r *rpcStream) Recv(msg interface{}) error { // we need to stay up to date with sequence numbers r.id = req.Id - if err := r.codec.ReadBody(msg); err != nil { + r.Unlock() + err = r.codec.ReadBody(msg) + r.Lock() + if err != nil { r.err = err return err } diff --git a/server/rpc_stream_test.go b/server/rpc_stream_test.go new file mode 100644 index 00000000..18a0f0a3 --- /dev/null +++ b/server/rpc_stream_test.go @@ -0,0 +1,133 @@ +package server + +import ( + "bytes" + "fmt" + "io" + "math/rand" + "sync" + "testing" + "time" + + "github.com/golang/protobuf/proto" + "github.com/micro/go-micro/codec/json" + protoCodec "github.com/micro/go-micro/codec/proto" +) + +// protoStruct implements proto.Message +type protoStruct struct { + Payload string `protobuf:"bytes,1,opt,name=service,proto3" json:"service,omitempty"` +} + +func (m *protoStruct) Reset() { *m = protoStruct{} } +func (m *protoStruct) String() string { return proto.CompactTextString(m) } +func (*protoStruct) ProtoMessage() {} + +// safeBuffer throws away everything and wont Read data back +type safeBuffer struct { + sync.RWMutex + buf []byte + off int +} + +func (b *safeBuffer) Write(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + // Cannot retain p, so we must copy it: + p2 := make([]byte, len(p)) + copy(p2, p) + b.Lock() + b.buf = append(b.buf, p2...) + b.Unlock() + return len(p2), nil +} + +func (b *safeBuffer) Read(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + b.RLock() + n = copy(p, b.buf[b.off:]) + b.RUnlock() + if n == 0 { + return 0, io.EOF + } + b.off += n + return n, nil +} + +func (b *safeBuffer) Close() error { + return nil +} + +func TestRPCStream_Sequence(t *testing.T) { + buffer := new(bytes.Buffer) + rwc := readWriteCloser{ + rbuf: buffer, + wbuf: buffer, + } + codec := json.NewCodec(&rwc) + streamServer := rpcStream{ + codec: codec, + request: &rpcRequest{ + codec: codec, + }, + } + + // Check if sequence is correct + for i := 0; i < 1000; i++ { + if err := streamServer.Send(fmt.Sprintf(`{"test":"value %d"}`, i)); err != nil { + t.Errorf("Unexpected Send error: %s", err) + } + } + + for i := 0; i < 1000; i++ { + var msg string + if err := streamServer.Recv(&msg); err != nil { + t.Errorf("Unexpected Recv error: %s", err) + } + if msg != fmt.Sprintf(`{"test":"value %d"}`, i) { + t.Errorf("Unexpected msg: %s", msg) + } + } +} + +func TestRPCStream_Concurrency(t *testing.T) { + buffer := new(safeBuffer) + codec := protoCodec.NewCodec(buffer) + streamServer := rpcStream{ + codec: codec, + request: &rpcRequest{ + codec: codec, + }, + } + + var wg sync.WaitGroup + // Check if race conditions happen + for i := 0; i < 10; i++ { + wg.Add(2) + + go func() { + for i := 0; i < 50; i++ { + msg := protoStruct{Payload: "test"} + <-time.After(time.Duration(rand.Intn(50)) * time.Millisecond) + if err := streamServer.Send(msg); err != nil { + t.Errorf("Unexpected Send error: %s", err) + } + } + wg.Done() + }() + + go func() { + for i := 0; i < 50; i++ { + <-time.After(time.Duration(rand.Intn(50)) * time.Millisecond) + if err := streamServer.Recv(&protoStruct{}); err != nil { + t.Errorf("Unexpected Recv error: %s", err) + } + } + wg.Done() + }() + } + wg.Wait() +}