From ec6a47c8945efcc7d739816b9fbb4436f377eb37 Mon Sep 17 00:00:00 2001 From: Ak-Army Date: Fri, 15 Jul 2022 12:00:13 +0200 Subject: [PATCH] HTTP Transport make streaming truly bidirectional (#2528) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [fix] http transport can send when wait for recv * [fix] http transport can send multiple message and recevie them. Do not block send and receive on stream mode * [fix] http transport can close the connection when recv is in progress, add tests Co-authored-by: Hunyadvári Péter --- transport/http_transport.go | 50 +++++++++++++----- transport/http_transport_test.go | 87 ++++++++++++++++++++++++++++++-- 2 files changed, 122 insertions(+), 15 deletions(-) diff --git a/transport/http_transport.go b/transport/http_transport.go index 449443f0..1c454f90 100644 --- a/transport/http_transport.go +++ b/transport/http_transport.go @@ -13,12 +13,13 @@ import ( "sync" "time" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" + maddr "github.com/asim/go-micro/v3/util/addr" "github.com/asim/go-micro/v3/util/buf" mnet "github.com/asim/go-micro/v3/util/net" mls "github.com/asim/go-micro/v3/util/tls" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" ) type httpTransport struct { @@ -103,14 +104,20 @@ func (h *httpTransportClient) Send(m *Message) error { Host: h.addr, } - h.Lock() - h.bl = append(h.bl, req) - select { - case h.r <- h.bl[0]: - h.bl = h.bl[1:] - default: + if !h.dialOpts.Stream { + h.Lock() + if h.closed { + h.Unlock() + return io.EOF + } + h.bl = append(h.bl, req) + select { + case h.r <- h.bl[0]: + h.bl = h.bl[1:] + default: + } + h.Unlock() } - h.Unlock() // set timeout if its greater than 0 if h.ht.opts.Timeout > time.Duration(0) { @@ -129,7 +136,14 @@ func (h *httpTransportClient) Recv(m *Message) error { if !h.dialOpts.Stream { rc, ok := <-h.r if !ok { - return io.EOF + h.Lock() + if len(h.bl) == 0 { + h.Unlock() + return io.EOF + } + rc = h.bl[0] + h.bl = h.bl[1:] + h.Unlock() } r = rc } @@ -141,6 +155,7 @@ func (h *httpTransportClient) Recv(m *Message) error { h.Lock() if h.closed { + h.Unlock() return io.EOF } rsp, err := http.ReadResponse(h.buff, r) @@ -177,6 +192,17 @@ func (h *httpTransportClient) Recv(m *Message) error { } func (h *httpTransportClient) Close() error { + if !h.dialOpts.Stream { + h.once.Do(func() { + h.Lock() + h.buff.Reset(nil) + h.closed = true + h.Unlock() + close(h.r) + }) + return h.conn.Close() + } + err := h.conn.Close() h.once.Do(func() { h.Lock() h.buff.Reset(nil) @@ -184,7 +210,7 @@ func (h *httpTransportClient) Close() error { h.Unlock() close(h.r) }) - return h.conn.Close() + return err } func (h *httpTransportSocket) Local() string { @@ -523,7 +549,7 @@ func (h *httpTransport) Dial(addr string, opts ...DialOption) (Client, error) { conn: conn, buff: bufio.NewReader(conn), dialOpts: dopts, - r: make(chan *http.Request, 1), + r: make(chan *http.Request, 100), local: conn.LocalAddr().String(), remote: conn.RemoteAddr().String(), }, nil diff --git a/transport/http_transport_test.go b/transport/http_transport_test.go index eb922167..debe5d90 100644 --- a/transport/http_transport_test.go +++ b/transport/http_transport_test.go @@ -1,7 +1,6 @@ package transport import ( - "fmt" "io" "net" "sync" @@ -305,9 +304,7 @@ func TestHTTPTransportCloseWhenRecv(t *testing.T) { if err == io.EOF { return } - t.Errorf("Unexpected recv err: %v", err) } - fmt.Println("aa") } }() for i := 1; i < 3; i++ { @@ -320,3 +317,87 @@ func TestHTTPTransportCloseWhenRecv(t *testing.T) { c.Close() wg.Wait() } + +func TestHTTPTransportMultipleSendWhenRecv(t *testing.T) { + tr := NewHTTPTransport() + + l, err := tr.Listen("127.0.0.1:0") + if err != nil { + t.Errorf("Unexpected listen err: %v", err) + } + defer l.Close() + + readyToSend := make(chan struct{}) + m := Message{ + Header: map[string]string{ + "Content-Type": "application/json", + }, + Body: []byte(`{"message": "Hello World"}`), + } + + wgSend := sync.WaitGroup{} + fn := func(sock Socket) { + defer sock.Close() + + for { + var mr Message + if err := sock.Recv(&mr); err != nil { + return + } + wgSend.Add(1) + go func() { + defer wgSend.Done() + <-readyToSend + if err := sock.Send(&m); err != nil { + return + } + }() + } + } + + done := make(chan bool) + + go func() { + if err := l.Accept(fn); err != nil { + select { + case <-done: + default: + t.Errorf("Unexpected accept err: %v", err) + } + } + }() + + c, err := tr.Dial(l.Addr(), WithStream()) + if err != nil { + t.Errorf("Unexpected dial err: %v", err) + } + defer c.Close() + + var wg sync.WaitGroup + wg.Add(1) + readyForRecv := make(chan struct{}) + go func() { + defer wg.Done() + close(readyForRecv) + for { + var rm Message + if err := c.Recv(&rm); err != nil { + if err == io.EOF { + return + } + } + } + }() + <-readyForRecv + for i := 0; i < 3; i++ { + if err := c.Send(&m); err != nil { + t.Errorf("Unexpected send err: %v", err) + } + } + close(readyToSend) + wgSend.Wait() + close(done) + + c.Close() + wg.Wait() +}