From b59125b5b7862553449dbd61669472ffe1f73640 Mon Sep 17 00:00:00 2001 From: Jesse Li Date: Thu, 2 Jan 2020 19:36:25 -0500 Subject: [PATCH] Rip out reader code --- client/client.go | 23 +++++++---------------- handshake/handshake.go | 3 +-- handshake/handshake_test.go | 3 +-- message/message.go | 3 +-- message/message_test.go | 3 +-- p2p/p2p.go | 17 +---------------- 6 files changed, 12 insertions(+), 40 deletions(-) diff --git a/client/client.go b/client/client.go index 8c2d1de..298e393 100644 --- a/client/client.go +++ b/client/client.go @@ -1,7 +1,6 @@ package client import ( - "bufio" "bytes" "fmt" "net" @@ -24,10 +23,9 @@ type Client struct { peer peers.Peer infoHash [20]byte peerID [20]byte - reader *bufio.Reader } -func completeHandshake(conn net.Conn, r *bufio.Reader, infohash, peerID [20]byte) (*handshake.Handshake, error) { +func completeHandshake(conn net.Conn, infohash, peerID [20]byte) (*handshake.Handshake, error) { conn.SetDeadline(time.Now().Add(3 * time.Second)) defer conn.SetDeadline(time.Time{}) // Disable the deadline @@ -37,7 +35,7 @@ func completeHandshake(conn net.Conn, r *bufio.Reader, infohash, peerID [20]byte return nil, err } - res, err := handshake.Read(r) + res, err := handshake.Read(conn) if err != nil { return nil, err } @@ -47,11 +45,11 @@ func completeHandshake(conn net.Conn, r *bufio.Reader, infohash, peerID [20]byte return res, nil } -func recvBitfield(conn net.Conn, r *bufio.Reader) (bitfield.Bitfield, error) { +func recvBitfield(conn net.Conn) (bitfield.Bitfield, error) { conn.SetDeadline(time.Now().Add(5 * time.Second)) defer conn.SetDeadline(time.Time{}) // Disable the deadline - msg, err := message.Read(r) + msg, err := message.Read(conn) if err != nil { return nil, err } @@ -71,15 +69,14 @@ func New(peer peers.Peer, peerID, infoHash [20]byte) (*Client, error) { if err != nil { return nil, err } - reader := bufio.NewReader(conn) - _, err = completeHandshake(conn, reader, infoHash, peerID) + _, err = completeHandshake(conn, infoHash, peerID) if err != nil { conn.Close() return nil, err } - bf, err := recvBitfield(conn, reader) + bf, err := recvBitfield(conn) if err != nil { conn.Close() return nil, err @@ -92,18 +89,12 @@ func New(peer peers.Peer, peerID, infoHash [20]byte) (*Client, error) { peer: peer, infoHash: infoHash, peerID: peerID, - reader: reader, }, nil } -// HasNext returns true if there are unread messages from the peer -func (c *Client) HasNext() bool { - return c.reader.Buffered() > 0 -} - // Read reads and consumes a message from the connection func (c *Client) Read() (*message.Message, error) { - msg, err := message.Read(c.reader) + msg, err := message.Read(c.Conn) return msg, err } diff --git a/handshake/handshake.go b/handshake/handshake.go index 4e05420..6885c58 100644 --- a/handshake/handshake.go +++ b/handshake/handshake.go @@ -1,7 +1,6 @@ package handshake import ( - "bufio" "fmt" "io" ) @@ -36,7 +35,7 @@ func (h *Handshake) Serialize() []byte { } // Read parses a handshake from a stream -func Read(r *bufio.Reader) (*Handshake, error) { +func Read(r io.Reader) (*Handshake, error) { lengthBuf := make([]byte, 1) _, err := io.ReadFull(r, lengthBuf) if err != nil { diff --git a/handshake/handshake_test.go b/handshake/handshake_test.go index e2bc237..054f055 100644 --- a/handshake/handshake_test.go +++ b/handshake/handshake_test.go @@ -1,7 +1,6 @@ package handshake import ( - "bufio" "bytes" "testing" @@ -74,7 +73,7 @@ func TestRead(t *testing.T) { } for _, test := range tests { - reader := bufio.NewReader(bytes.NewReader(test.input)) + reader := bytes.NewReader(test.input) m, err := Read(reader) if test.fails { assert.NotNil(t, err) diff --git a/message/message.go b/message/message.go index ff93662..c05f67d 100644 --- a/message/message.go +++ b/message/message.go @@ -1,7 +1,6 @@ package message import ( - "bufio" "encoding/binary" "fmt" "io" @@ -104,7 +103,7 @@ func (m *Message) Serialize() []byte { } // Read parses a message from a stream. Returns `nil` on keep-alive message -func Read(r *bufio.Reader) (*Message, error) { +func Read(r io.Reader) (*Message, error) { lengthBuf := make([]byte, 4) _, err := io.ReadFull(r, lengthBuf) if err != nil { diff --git a/message/message_test.go b/message/message_test.go index 8f30fff..b7cd363 100644 --- a/message/message_test.go +++ b/message/message_test.go @@ -1,7 +1,6 @@ package message import ( - "bufio" "bytes" "testing" @@ -228,7 +227,7 @@ func TestRead(t *testing.T) { } for _, test := range tests { - reader := bufio.NewReader(bytes.NewReader(test.input)) + reader := bytes.NewReader(test.input) m, err := Read(reader) if test.fails { assert.NotNil(t, err) diff --git a/p2p/p2p.go b/p2p/p2p.go index 2612823..32bbc6b 100644 --- a/p2p/p2p.go +++ b/p2p/p2p.go @@ -82,20 +82,6 @@ func (state *pieceProgress) readMessage() error { return nil } -func (state *pieceProgress) readMessages() error { - err := state.readMessage() - if err != nil { - return err - } - for state.client.HasNext() { - err := state.readMessage() - if err != nil { - return err - } - } - return nil -} - func attemptDownloadPiece(c *client.Client, pw *pieceWork) ([]byte, error) { state := pieceProgress{ index: pw.index, @@ -127,8 +113,7 @@ func attemptDownloadPiece(c *client.Client, pw *pieceWork) ([]byte, error) { } } - // Wait until we receive at least one message, and consume them - err := state.readMessages() + err := state.readMessage() if err != nil { return nil, err }