diff --git a/client/rpc_stream.go b/client/rpc_stream.go index 2a55820f..515dc5b1 100644 --- a/client/rpc_stream.go +++ b/client/rpc_stream.go @@ -75,10 +75,10 @@ func (r *rpcStream) Send(msg interface{}) error { func (r *rpcStream) Recv(msg interface{}) error { r.Lock() - defer r.Unlock() if r.isClosed() { r.err = errShutdown + r.Unlock() return errShutdown } @@ -90,9 +90,12 @@ func (r *rpcStream) Recv(msg interface{}) error { if err != nil { if err == io.EOF && !r.isClosed() { r.err = io.ErrUnexpectedEOF + r.Unlock() return io.ErrUnexpectedEOF } r.err = err + + r.Unlock() return err } @@ -121,6 +124,7 @@ func (r *rpcStream) Recv(msg interface{}) error { } } + r.Unlock() return r.err } diff --git a/transport/http_transport.go b/transport/http_transport.go index 8b44b161..ed311eb4 100644 --- a/transport/http_transport.go +++ b/transport/http_transport.go @@ -12,12 +12,13 @@ import ( "sync" "time" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" + maddr "go-micro.dev/v4/util/addr" "go-micro.dev/v4/util/buf" mnet "go-micro.dev/v4/util/net" mls "go-micro.dev/v4/util/tls" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" ) type httpTransport struct { @@ -34,9 +35,10 @@ type httpTransportClient struct { sync.RWMutex // request must be stored for response processing - r chan *http.Request - bl []*http.Request - buff *bufio.Reader + r chan *http.Request + bl []*http.Request + buff *bufio.Reader + closed bool // local/remote ip local string @@ -137,7 +139,12 @@ func (h *httpTransportClient) Recv(m *Message) error { h.conn.SetDeadline(time.Now().Add(h.ht.opts.Timeout)) } + h.Lock() + if h.closed { + return io.EOF + } rsp, err := http.ReadResponse(h.buff, r) + h.Unlock() if err != nil { return err } @@ -173,6 +180,7 @@ func (h *httpTransportClient) Close() error { h.once.Do(func() { h.Lock() h.buff.Reset(nil) + h.closed = true h.Unlock() close(h.r) }) diff --git a/transport/http_transport_test.go b/transport/http_transport_test.go index 85bea2f6..eb922167 100644 --- a/transport/http_transport_test.go +++ b/transport/http_transport_test.go @@ -1,8 +1,10 @@ package transport import ( + "fmt" "io" "net" + "sync" "testing" "time" ) @@ -244,3 +246,77 @@ func TestHTTPTransportTimeout(t *testing.T) { <-done } + +func TestHTTPTransportCloseWhenRecv(t *testing.T) { + tr := NewHTTPTransport() + + l, err := tr.Listen("127.0.0.1:0") + if err != nil { + t.Errorf("Unexpected listen err: %v", err) + } + defer l.Close() + + fn := func(sock Socket) { + defer sock.Close() + + for { + var m Message + if err := sock.Recv(&m); err != nil { + return + } + if err := sock.Send(&m); err != nil { + return + } + } + } + + done := make(chan bool) + + go func() { + if err := l.Accept(fn); err != nil { + select { + case <-done: + default: + t.Errorf("Unexpected accept err: %v", err) + } + } + }() + + c, err := tr.Dial(l.Addr()) + if err != nil { + t.Errorf("Unexpected dial err: %v", err) + } + defer c.Close() + + m := Message{ + Header: map[string]string{ + "Content-Type": "application/json", + }, + Body: []byte(`{"message": "Hello World"}`), + } + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for { + var rm Message + + if err := c.Recv(&rm); err != nil { + if err == io.EOF { + return + } + t.Errorf("Unexpected recv err: %v", err) + } + fmt.Println("aa") + } + }() + for i := 1; i < 3; i++ { + if err := c.Send(&m); err != nil { + t.Errorf("Unexpected send err: %v", err) + } + } + close(done) + + c.Close() + wg.Wait() +}