1
0
mirror of https://github.com/veggiedefender/torrent-client.git synced 2025-11-06 09:29:16 +02:00

Incredibly messy and slow download

This commit is contained in:
Jesse
2019-12-26 21:53:11 -05:00
parent b1fc8c7fb8
commit 7e8cac2d3e
8 changed files with 290 additions and 46 deletions

View File

@@ -1,6 +1,7 @@
package handshake package handshake
import ( import (
"bufio"
"errors" "errors"
"io" "io"
) )
@@ -35,7 +36,7 @@ func (h *Handshake) Serialize() []byte {
} }
// Read parses a message from a stream. Returns `nil` on keep-alive message // Read parses a message from a stream. Returns `nil` on keep-alive message
func Read(r io.Reader) (*Handshake, error) { func Read(r *bufio.Reader) (*Handshake, error) {
lengthBuf := make([]byte, 1) lengthBuf := make([]byte, 1)
_, err := io.ReadFull(r, lengthBuf) _, err := io.ReadFull(r, lengthBuf)
if err != nil { if err != nil {

View File

@@ -1,6 +1,7 @@
package handshake package handshake
import ( import (
"bufio"
"bytes" "bytes"
"testing" "testing"
@@ -73,7 +74,7 @@ func TestRead(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
reader := bytes.NewReader(test.input) reader := bufio.NewReader(bytes.NewReader(test.input))
m, err := Read(reader) m, err := Read(reader)
if test.fails { if test.fails {
assert.NotNil(t, err) assert.NotNil(t, err)

27
main.go
View File

@@ -8,17 +8,30 @@ import (
) )
func main() { func main() {
file, err := os.Open(os.Args[1]) inPath := os.Args[1]
if err != nil { outPath := os.Args[2]
log.Fatal(err)
}
defer file.Close()
t, err := torrent.Open(file) inFile, err := os.Open(inPath)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
err = t.Download() defer inFile.Close()
t, err := torrent.Open(inFile)
if err != nil {
log.Fatal(err)
}
buf, err := t.Download()
if err != nil {
log.Fatal(err)
}
outFile, err := os.Create(outPath)
if err != nil {
log.Fatal(err)
}
defer outFile.Close()
_, err = outFile.Write(buf)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

View File

@@ -1,6 +1,7 @@
package message package message
import ( import (
"bufio"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
@@ -84,7 +85,7 @@ func (m *Message) Serialize() []byte {
} }
// Read parses a message from a stream. Returns `nil` on keep-alive message // Read parses a message from a stream. Returns `nil` on keep-alive message
func Read(r io.Reader) (*Message, error) { func Read(r *bufio.Reader) (*Message, error) {
lengthBuf := make([]byte, 4) lengthBuf := make([]byte, 4)
_, err := io.ReadFull(r, lengthBuf) _, err := io.ReadFull(r, lengthBuf)
if err != nil { if err != nil {
@@ -115,7 +116,6 @@ func (m *Message) String() string {
if m == nil { if m == nil {
return "KeepAlive" return "KeepAlive"
} }
switch m.ID { switch m.ID {
case MsgChoke: case MsgChoke:
return "Choke" return "Choke"

View File

@@ -1,6 +1,7 @@
package message package message
import ( import (
"bufio"
"bytes" "bytes"
"testing" "testing"
@@ -179,7 +180,7 @@ func TestRead(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
reader := bytes.NewReader(test.input) reader := bufio.NewReader(bytes.NewReader(test.input))
m, err := Read(reader) m, err := Read(reader)
if test.fails { if test.fails {
assert.NotNil(t, err) assert.NotNil(t, err)

View File

@@ -1,10 +1,11 @@
package p2p package p2p
import ( import (
"bufio"
"encoding/binary"
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
"sync"
"time" "time"
"github.com/veggiedefender/torrent-client/message" "github.com/veggiedefender/torrent-client/message"
@@ -13,13 +14,17 @@ import (
) )
type client struct { type client struct {
peer Peer
infoHash [20]byte
peerID [20]byte
conn net.Conn conn net.Conn
reader *bufio.Reader
bitfield message.Bitfield bitfield message.Bitfield
Choked bool choked bool
Mux sync.Mutex engaged bool
} }
func completeHandshake(conn net.Conn, infohash, peerID [20]byte) (*handshake.Handshake, error) { func completeHandshake(conn net.Conn, r *bufio.Reader, infohash, peerID [20]byte) (*handshake.Handshake, error) {
conn.SetDeadline(time.Now().Local().Add(3 * time.Second)) conn.SetDeadline(time.Now().Local().Add(3 * time.Second))
defer conn.SetDeadline(time.Time{}) // Disable the deadline defer conn.SetDeadline(time.Time{}) // Disable the deadline
@@ -29,18 +34,18 @@ func completeHandshake(conn net.Conn, infohash, peerID [20]byte) (*handshake.Han
return nil, err return nil, err
} }
res, err := handshake.Read(conn) res, err := handshake.Read(r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return res, nil return res, nil
} }
func recvBitfield(conn net.Conn) (message.Bitfield, error) { func recvBitfield(conn net.Conn, r *bufio.Reader) (message.Bitfield, error) {
conn.SetDeadline(time.Now().Local().Add(5 * time.Second)) conn.SetDeadline(time.Now().Local().Add(5 * time.Second))
defer conn.SetDeadline(time.Time{}) // Disable the deadline defer conn.SetDeadline(time.Time{}) // Disable the deadline
msg, err := message.Read(conn) msg, err := message.Read(r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -58,22 +63,72 @@ func newClient(peer Peer, peerID, infoHash [20]byte) (*client, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, err = completeHandshake(conn, infoHash, peerID) reader := bufio.NewReader(conn)
_, err = completeHandshake(conn, reader, infoHash, peerID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
bf, err := recvBitfield(conn)
bf, err := recvBitfield(conn, reader)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &client{ return &client{
peer: peer,
infoHash: infoHash,
peerID: peerID,
conn: conn, conn: conn,
reader: reader,
bitfield: bf, bitfield: bf,
Mux: sync.Mutex{}, choked: true,
Choked: true, engaged: false,
}, nil }, nil
} }
func (c *client) hasPiece(index int) bool { func (c *client) hasPiece(index int) bool {
return c.bitfield.HasPiece(index) return c.bitfield.HasPiece(index)
} }
func (c *client) hasNext() bool {
fmt.Println(c.reader.Buffered() > 0)
return c.reader.Buffered() > 0
}
func (c *client) read() (*message.Message, error) {
msg, err := message.Read(c.reader)
return msg, err
}
func (c *client) request(index, begin, length int) error {
req := message.FormatRequest(index, begin, length)
_, err := c.conn.Write(req.Serialize())
return err
}
func (c *client) interested() error {
msg := message.Message{ID: message.MsgInterested}
_, err := c.conn.Write(msg.Serialize())
return err
}
func (c *client) notInterested() error {
msg := message.Message{ID: message.MsgNotInterested}
_, err := c.conn.Write(msg.Serialize())
return err
}
func (c *client) unchoke() error {
msg := message.Message{ID: message.MsgUnchoke}
_, err := c.conn.Write(msg.Serialize())
return err
}
func (c *client) have(index int) error {
pl := make([]byte, 4)
binary.BigEndian.PutUint32(pl, uint32(index))
msg := message.Message{ID: message.MsgHave, Payload: pl}
_, err := c.conn.Write(msg.Serialize())
return err
}

View File

@@ -1,11 +1,18 @@
package p2p package p2p
import ( import (
"bytes"
"crypto/sha1"
"fmt" "fmt"
"log"
"net" "net"
"sync" "sync"
"github.com/veggiedefender/torrent-client/message"
) )
const maxBlockSize = 32768
// Peer encodes connection information for a peer // Peer encodes connection information for a peer
type Peer struct { type Peer struct {
IP net.IP IP net.IP
@@ -29,23 +36,24 @@ type pieceWork struct {
type swarm struct { type swarm struct {
clients []*client clients []*client
queue chan *pieceWork queue chan *pieceWork
buf []byte
piecesDone int
mux sync.Mutex mux sync.Mutex
} }
// Download downloads a torrent // Download downloads a torrent
func (d *Download) Download() error { func (d *Download) Download() ([]byte, error) {
clients := d.initClients() clients := d.initClients()
if len(clients) == 0 { if len(clients) == 0 {
return fmt.Errorf("Could not connect to any of %d clients", len(d.Peers)) return nil, fmt.Errorf("Could not connect to any of %d clients", len(d.Peers))
} }
log.Printf("Connected to %d clients\n", len(clients))
queue := make(chan *pieceWork, len(d.PieceHashes)) queue := make(chan *pieceWork, len(d.PieceHashes))
for index, hash := range d.PieceHashes { for index, hash := range d.PieceHashes {
queue <- &pieceWork{index, hash} queue <- &pieceWork{index, hash}
} }
processQueue(clients, queue) return d.processQueue(clients, queue), nil
return nil
} }
func (d *Download) initClients() []*client { func (d *Download) initClients() []*client {
@@ -62,6 +70,7 @@ func (d *Download) initClients() []*client {
}(p) }(p)
} }
// Gather clients into a slice
clients := make([]*client, 0) clients := make([]*client, 0)
for range d.Peers { for range d.Peers {
client := <-c client := <-c
@@ -72,15 +81,177 @@ func (d *Download) initClients() []*client {
return clients return clients
} }
func (s *swarm) selectClient(index int) *client { func (s *swarm) selectClient(index int) (*client, error) {
return s.clients[0] for _, c := range s.clients {
if !c.engaged && !c.choked && c.hasPiece(index) {
return c, nil
}
}
for _, c := range s.clients {
if !c.engaged && c.hasPiece(index) {
return c, nil
}
}
return nil, fmt.Errorf("Could not find client for piece %d", index)
} }
func processQueue(clients []*client, queue chan *pieceWork) { func calculateBoundsForPiece(index, numPieces, length int) (int, int) {
s := swarm{clients, queue, sync.Mutex{}} pieceLength := length / numPieces
for pw := range s.queue { begin := index * pieceLength
client := s.selectClient(pw.index) end := begin + pieceLength
fmt.Println(client.conn.RemoteAddr()) if end > length-1 {
end = length - 1
}
return begin, end
}
func checkIntegrity(pw *pieceWork, buf []byte) error {
hash := sha1.Sum(buf)
if !bytes.Equal(hash[:], pw.hash[:]) {
return fmt.Errorf("Index %d failed integrity check", pw.index)
}
return nil
}
func downloadPiece(c *client, pw *pieceWork, pieceLength int) ([]byte, error) {
buf := make([]byte, pieceLength)
c.unchoke()
c.interested()
offset := 0
for c.hasNext() {
msg, err := c.read() // this call blocks
if err != nil {
return nil, err
}
if msg == nil { // keep-alive
continue
}
fmt.Println("CATCHING UP ON", msg)
switch msg.ID {
case message.MsgUnchoke:
c.choked = false
case message.MsgChoke:
c.choked = true
}
}
for offset < pieceLength {
if !c.choked {
blockSize := maxBlockSize
if pieceLength-offset < blockSize {
// Last block might be shorter than the typical block
blockSize = pieceLength - offset
}
c.request(pw.index, offset, blockSize)
}
msg, err := c.read() // this call blocks
if err != nil {
return nil, err
}
if msg == nil { // keep-alive
continue
}
if msg.ID != message.MsgPiece {
fmt.Println(msg)
}
switch msg.ID {
case message.MsgUnchoke:
c.choked = false
case message.MsgChoke:
c.choked = true
case message.MsgPiece:
n, err := message.ParsePiece(pw.index, buf, msg)
if err != nil {
return nil, err
}
offset += n
}
}
c.have(pw.index)
c.notInterested()
err := checkIntegrity(pw, buf)
if err != nil {
return nil, err
}
return buf, nil
}
func (s *swarm) removeClient(c *client) {
log.Printf("Removing client. %d clients remaining\n", len(s.clients))
s.mux.Lock()
var i int
for i = 0; i < len(s.clients); i++ {
if s.clients[i] == c {
break break
} }
} }
s.clients = append(s.clients[:i], s.clients[i+1:]...)
s.mux.Unlock()
}
func (s *swarm) worker(d *Download, wg *sync.WaitGroup) {
for pw := range s.queue {
s.mux.Lock()
c, err := s.selectClient(pw.index)
if err != nil {
// Re-enqueue the piece to try again
s.queue <- pw
s.mux.Unlock()
continue
}
c.engaged = true
s.mux.Unlock()
begin, end := calculateBoundsForPiece(pw.index, len(d.PieceHashes), d.Length)
pieceLength := end - begin
pieceBuf, err := downloadPiece(c, pw, pieceLength)
if err != nil {
// Re-enqueue the piece to try again
log.Println(err)
s.removeClient(c)
s.queue <- pw
} else {
// Copy into buffer should not overlap with other workers
copy(s.buf[begin:end], pieceBuf)
s.mux.Lock()
s.piecesDone++
log.Printf("Downloaded piece %d (%d/%d) %0.2f%%\n", pw.index, s.piecesDone, len(d.PieceHashes), float64(s.piecesDone)/float64(len(d.PieceHashes)))
s.mux.Unlock()
if s.piecesDone == len(d.PieceHashes) {
close(s.queue)
}
}
s.mux.Lock()
c.engaged = false
s.mux.Unlock()
}
wg.Done()
}
func (d *Download) processQueue(clients []*client, queue chan *pieceWork) []byte {
s := swarm{
clients: clients,
queue: queue,
buf: make([]byte, d.Length),
mux: sync.Mutex{},
}
numWorkers := len(s.clients) / 2
log.Printf("Spawning %d workers\n", numWorkers)
wg := sync.WaitGroup{}
wg.Add(numWorkers)
for i := 0; i < numWorkers; i++ {
go s.worker(d, &wg)
}
wg.Wait()
return s.buf
}

View File

@@ -6,7 +6,6 @@ import (
"crypto/sha1" "crypto/sha1"
"fmt" "fmt"
"io" "io"
"net"
"github.com/jackpal/bencode-go" "github.com/jackpal/bencode-go"
"github.com/veggiedefender/torrent-client/p2p" "github.com/veggiedefender/torrent-client/p2p"
@@ -38,15 +37,15 @@ type bencodeTorrent struct {
} }
// Download downloads a torrent // Download downloads a torrent
func (t *Torrent) Download() error { func (t *Torrent) Download() ([]byte, error) {
var peerID [20]byte var peerID [20]byte
_, err := rand.Read(peerID[:]) _, err := rand.Read(peerID[:])
if err != nil { if err != nil {
return err return nil, err
} }
// peers, err := t.getPeers(peerID, Port) peers, err := t.getPeers(peerID, Port)
peers := []p2p.Peer{{IP: net.IP{127, 0, 0, 1}, Port: 51413}} // peers := []p2p.Peer{{IP: net.IP{127, 0, 0, 1}, Port: 51413}}
downloader := p2p.Download{ downloader := p2p.Download{
Peers: peers, Peers: peers,
PeerID: peerID, PeerID: peerID,
@@ -54,8 +53,11 @@ func (t *Torrent) Download() error {
PieceHashes: t.PieceHashes, PieceHashes: t.PieceHashes,
Length: t.Length, Length: t.Length,
} }
err = downloader.Download() buf, err := downloader.Download()
return err if err != nil {
return nil, err
}
return buf, nil
} }
// Open parses a torrent file // Open parses a torrent file