1
0
mirror of https://github.com/veggiedefender/torrent-client.git synced 2025-11-06 09:29:16 +02:00

Rip out reader code

This commit is contained in:
Jesse Li
2020-01-02 19:36:25 -05:00
parent e9edc62c57
commit b59125b5b7
6 changed files with 12 additions and 40 deletions

View File

@@ -1,7 +1,6 @@
package client
import (
"bufio"
"bytes"
"fmt"
"net"
@@ -24,10 +23,9 @@ type Client struct {
peer peers.Peer
infoHash [20]byte
peerID [20]byte
reader *bufio.Reader
}
func completeHandshake(conn net.Conn, r *bufio.Reader, infohash, peerID [20]byte) (*handshake.Handshake, error) {
func completeHandshake(conn net.Conn, infohash, peerID [20]byte) (*handshake.Handshake, error) {
conn.SetDeadline(time.Now().Add(3 * time.Second))
defer conn.SetDeadline(time.Time{}) // Disable the deadline
@@ -37,7 +35,7 @@ func completeHandshake(conn net.Conn, r *bufio.Reader, infohash, peerID [20]byte
return nil, err
}
res, err := handshake.Read(r)
res, err := handshake.Read(conn)
if err != nil {
return nil, err
}
@@ -47,11 +45,11 @@ func completeHandshake(conn net.Conn, r *bufio.Reader, infohash, peerID [20]byte
return res, nil
}
func recvBitfield(conn net.Conn, r *bufio.Reader) (bitfield.Bitfield, error) {
func recvBitfield(conn net.Conn) (bitfield.Bitfield, error) {
conn.SetDeadline(time.Now().Add(5 * time.Second))
defer conn.SetDeadline(time.Time{}) // Disable the deadline
msg, err := message.Read(r)
msg, err := message.Read(conn)
if err != nil {
return nil, err
}
@@ -71,15 +69,14 @@ func New(peer peers.Peer, peerID, infoHash [20]byte) (*Client, error) {
if err != nil {
return nil, err
}
reader := bufio.NewReader(conn)
_, err = completeHandshake(conn, reader, infoHash, peerID)
_, err = completeHandshake(conn, infoHash, peerID)
if err != nil {
conn.Close()
return nil, err
}
bf, err := recvBitfield(conn, reader)
bf, err := recvBitfield(conn)
if err != nil {
conn.Close()
return nil, err
@@ -92,18 +89,12 @@ func New(peer peers.Peer, peerID, infoHash [20]byte) (*Client, error) {
peer: peer,
infoHash: infoHash,
peerID: peerID,
reader: reader,
}, nil
}
// HasNext returns true if there are unread messages from the peer
func (c *Client) HasNext() bool {
return c.reader.Buffered() > 0
}
// Read reads and consumes a message from the connection
func (c *Client) Read() (*message.Message, error) {
msg, err := message.Read(c.reader)
msg, err := message.Read(c.Conn)
return msg, err
}

View File

@@ -1,7 +1,6 @@
package handshake
import (
"bufio"
"fmt"
"io"
)
@@ -36,7 +35,7 @@ func (h *Handshake) Serialize() []byte {
}
// Read parses a handshake from a stream
func Read(r *bufio.Reader) (*Handshake, error) {
func Read(r io.Reader) (*Handshake, error) {
lengthBuf := make([]byte, 1)
_, err := io.ReadFull(r, lengthBuf)
if err != nil {

View File

@@ -1,7 +1,6 @@
package handshake
import (
"bufio"
"bytes"
"testing"
@@ -74,7 +73,7 @@ func TestRead(t *testing.T) {
}
for _, test := range tests {
reader := bufio.NewReader(bytes.NewReader(test.input))
reader := bytes.NewReader(test.input)
m, err := Read(reader)
if test.fails {
assert.NotNil(t, err)

View File

@@ -1,7 +1,6 @@
package message
import (
"bufio"
"encoding/binary"
"fmt"
"io"
@@ -104,7 +103,7 @@ func (m *Message) Serialize() []byte {
}
// Read parses a message from a stream. Returns `nil` on keep-alive message
func Read(r *bufio.Reader) (*Message, error) {
func Read(r io.Reader) (*Message, error) {
lengthBuf := make([]byte, 4)
_, err := io.ReadFull(r, lengthBuf)
if err != nil {

View File

@@ -1,7 +1,6 @@
package message
import (
"bufio"
"bytes"
"testing"
@@ -228,7 +227,7 @@ func TestRead(t *testing.T) {
}
for _, test := range tests {
reader := bufio.NewReader(bytes.NewReader(test.input))
reader := bytes.NewReader(test.input)
m, err := Read(reader)
if test.fails {
assert.NotNil(t, err)

View File

@@ -82,20 +82,6 @@ func (state *pieceProgress) readMessage() error {
return nil
}
func (state *pieceProgress) readMessages() error {
err := state.readMessage()
if err != nil {
return err
}
for state.client.HasNext() {
err := state.readMessage()
if err != nil {
return err
}
}
return nil
}
func attemptDownloadPiece(c *client.Client, pw *pieceWork) ([]byte, error) {
state := pieceProgress{
index: pw.index,
@@ -127,8 +113,7 @@ func attemptDownloadPiece(c *client.Client, pw *pieceWork) ([]byte, error) {
}
}
// Wait until we receive at least one message, and consume them
err := state.readMessages()
err := state.readMessage()
if err != nil {
return nil, err
}