1
0
mirror of https://github.com/veggiedefender/torrent-client.git synced 2025-11-06 09:29:16 +02:00

Rip out reader code

This commit is contained in:
Jesse Li
2020-01-02 19:36:25 -05:00
parent e9edc62c57
commit b59125b5b7
6 changed files with 12 additions and 40 deletions

View File

@@ -1,7 +1,6 @@
package client package client
import ( import (
"bufio"
"bytes" "bytes"
"fmt" "fmt"
"net" "net"
@@ -24,10 +23,9 @@ type Client struct {
peer peers.Peer peer peers.Peer
infoHash [20]byte infoHash [20]byte
peerID [20]byte peerID [20]byte
reader *bufio.Reader
} }
func completeHandshake(conn net.Conn, r *bufio.Reader, infohash, peerID [20]byte) (*handshake.Handshake, error) { func completeHandshake(conn net.Conn, infohash, peerID [20]byte) (*handshake.Handshake, error) {
conn.SetDeadline(time.Now().Add(3 * time.Second)) conn.SetDeadline(time.Now().Add(3 * time.Second))
defer conn.SetDeadline(time.Time{}) // Disable the deadline defer conn.SetDeadline(time.Time{}) // Disable the deadline
@@ -37,7 +35,7 @@ func completeHandshake(conn net.Conn, r *bufio.Reader, infohash, peerID [20]byte
return nil, err return nil, err
} }
res, err := handshake.Read(r) res, err := handshake.Read(conn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -47,11 +45,11 @@ func completeHandshake(conn net.Conn, r *bufio.Reader, infohash, peerID [20]byte
return res, nil return res, nil
} }
func recvBitfield(conn net.Conn, r *bufio.Reader) (bitfield.Bitfield, error) { func recvBitfield(conn net.Conn) (bitfield.Bitfield, error) {
conn.SetDeadline(time.Now().Add(5 * time.Second)) conn.SetDeadline(time.Now().Add(5 * time.Second))
defer conn.SetDeadline(time.Time{}) // Disable the deadline defer conn.SetDeadline(time.Time{}) // Disable the deadline
msg, err := message.Read(r) msg, err := message.Read(conn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -71,15 +69,14 @@ func New(peer peers.Peer, peerID, infoHash [20]byte) (*Client, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
reader := bufio.NewReader(conn)
_, err = completeHandshake(conn, reader, infoHash, peerID) _, err = completeHandshake(conn, infoHash, peerID)
if err != nil { if err != nil {
conn.Close() conn.Close()
return nil, err return nil, err
} }
bf, err := recvBitfield(conn, reader) bf, err := recvBitfield(conn)
if err != nil { if err != nil {
conn.Close() conn.Close()
return nil, err return nil, err
@@ -92,18 +89,12 @@ func New(peer peers.Peer, peerID, infoHash [20]byte) (*Client, error) {
peer: peer, peer: peer,
infoHash: infoHash, infoHash: infoHash,
peerID: peerID, peerID: peerID,
reader: reader,
}, nil }, nil
} }
// HasNext returns true if there are unread messages from the peer
func (c *Client) HasNext() bool {
return c.reader.Buffered() > 0
}
// Read reads and consumes a message from the connection // Read reads and consumes a message from the connection
func (c *Client) Read() (*message.Message, error) { func (c *Client) Read() (*message.Message, error) {
msg, err := message.Read(c.reader) msg, err := message.Read(c.Conn)
return msg, err return msg, err
} }

View File

@@ -1,7 +1,6 @@
package handshake package handshake
import ( import (
"bufio"
"fmt" "fmt"
"io" "io"
) )
@@ -36,7 +35,7 @@ func (h *Handshake) Serialize() []byte {
} }
// Read parses a handshake from a stream // Read parses a handshake from a stream
func Read(r *bufio.Reader) (*Handshake, error) { func Read(r io.Reader) (*Handshake, error) {
lengthBuf := make([]byte, 1) lengthBuf := make([]byte, 1)
_, err := io.ReadFull(r, lengthBuf) _, err := io.ReadFull(r, lengthBuf)
if err != nil { if err != nil {

View File

@@ -1,7 +1,6 @@
package handshake package handshake
import ( import (
"bufio"
"bytes" "bytes"
"testing" "testing"
@@ -74,7 +73,7 @@ func TestRead(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
reader := bufio.NewReader(bytes.NewReader(test.input)) reader := bytes.NewReader(test.input)
m, err := Read(reader) m, err := Read(reader)
if test.fails { if test.fails {
assert.NotNil(t, err) assert.NotNil(t, err)

View File

@@ -1,7 +1,6 @@
package message package message
import ( import (
"bufio"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
@@ -104,7 +103,7 @@ func (m *Message) Serialize() []byte {
} }
// Read parses a message from a stream. Returns `nil` on keep-alive message // Read parses a message from a stream. Returns `nil` on keep-alive message
func Read(r *bufio.Reader) (*Message, error) { func Read(r io.Reader) (*Message, error) {
lengthBuf := make([]byte, 4) lengthBuf := make([]byte, 4)
_, err := io.ReadFull(r, lengthBuf) _, err := io.ReadFull(r, lengthBuf)
if err != nil { if err != nil {

View File

@@ -1,7 +1,6 @@
package message package message
import ( import (
"bufio"
"bytes" "bytes"
"testing" "testing"
@@ -228,7 +227,7 @@ func TestRead(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
reader := bufio.NewReader(bytes.NewReader(test.input)) reader := bytes.NewReader(test.input)
m, err := Read(reader) m, err := Read(reader)
if test.fails { if test.fails {
assert.NotNil(t, err) assert.NotNil(t, err)

View File

@@ -82,20 +82,6 @@ func (state *pieceProgress) readMessage() error {
return nil return nil
} }
func (state *pieceProgress) readMessages() error {
err := state.readMessage()
if err != nil {
return err
}
for state.client.HasNext() {
err := state.readMessage()
if err != nil {
return err
}
}
return nil
}
func attemptDownloadPiece(c *client.Client, pw *pieceWork) ([]byte, error) { func attemptDownloadPiece(c *client.Client, pw *pieceWork) ([]byte, error) {
state := pieceProgress{ state := pieceProgress{
index: pw.index, index: pw.index,
@@ -127,8 +113,7 @@ func attemptDownloadPiece(c *client.Client, pw *pieceWork) ([]byte, error) {
} }
} }
// Wait until we receive at least one message, and consume them err := state.readMessage()
err := state.readMessages()
if err != nil { if err != nil {
return nil, err return nil, err
} }