You've already forked torrent-client
mirror of
https://github.com/veggiedefender/torrent-client.git
synced 2025-11-06 09:29:16 +02:00
Refactor p2p
This commit is contained in:
@@ -20,7 +20,6 @@ type client struct {
|
||||
reader *bufio.Reader
|
||||
bitfield message.Bitfield
|
||||
choked bool
|
||||
engaged bool
|
||||
}
|
||||
|
||||
func completeHandshake(conn net.Conn, r *bufio.Reader, infohash, peerID [20]byte) (*handshake.Handshake, error) {
|
||||
@@ -66,11 +65,13 @@ func newClient(peer Peer, peerID, infoHash [20]byte) (*client, error) {
|
||||
|
||||
_, err = completeHandshake(conn, reader, infoHash, peerID)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bf, err := recvBitfield(conn, reader)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -82,7 +83,6 @@ func newClient(peer Peer, peerID, infoHash [20]byte) (*client, error) {
|
||||
reader: reader,
|
||||
bitfield: bf,
|
||||
choked: true,
|
||||
engaged: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
248
p2p/p2p.go
248
p2p/p2p.go
@@ -6,13 +6,13 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"runtime"
|
||||
|
||||
"github.com/veggiedefender/torrent-client/message"
|
||||
)
|
||||
|
||||
const maxBlockSize = 16384
|
||||
const queueSize = 5
|
||||
const maxUnfulfilled = 5
|
||||
|
||||
// Peer encodes connection information for a peer
|
||||
type Peer struct {
|
||||
@@ -20,8 +20,8 @@ type Peer struct {
|
||||
Port uint16
|
||||
}
|
||||
|
||||
// Download holds data required to download a torrent from a list of peers
|
||||
type Download struct {
|
||||
// Torrent holds data required to download a torrent from a list of peers
|
||||
type Torrent struct {
|
||||
Peers []Peer
|
||||
PeerID [20]byte
|
||||
InfoHash [20]byte
|
||||
@@ -30,77 +30,14 @@ type Download struct {
|
||||
}
|
||||
|
||||
type pieceWork struct {
|
||||
index int
|
||||
hash [20]byte
|
||||
length int
|
||||
}
|
||||
|
||||
type pieceResult struct {
|
||||
index int
|
||||
hash [20]byte
|
||||
}
|
||||
|
||||
type swarm struct {
|
||||
clients []*client
|
||||
queue chan *pieceWork
|
||||
buf []byte
|
||||
piecesDone int
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
// Download downloads a torrent
|
||||
func (d *Download) Download() ([]byte, error) {
|
||||
clients := d.initClients()
|
||||
if len(clients) == 0 {
|
||||
return nil, fmt.Errorf("Could not connect to any of %d clients", len(d.Peers))
|
||||
}
|
||||
log.Printf("Connected to %d clients\n", len(clients))
|
||||
|
||||
queue := make(chan *pieceWork, len(d.PieceHashes))
|
||||
for index, hash := range d.PieceHashes {
|
||||
queue <- &pieceWork{index, hash}
|
||||
}
|
||||
return d.processQueue(clients, queue), nil
|
||||
}
|
||||
|
||||
func (d *Download) initClients() []*client {
|
||||
// Create clients in parallel
|
||||
c := make(chan *client)
|
||||
for _, p := range d.Peers {
|
||||
go func(p Peer) {
|
||||
client, err := newClient(p, d.PeerID, d.InfoHash)
|
||||
if err != nil {
|
||||
c <- nil
|
||||
} else {
|
||||
c <- client
|
||||
}
|
||||
}(p)
|
||||
}
|
||||
|
||||
// Gather clients into a slice
|
||||
clients := make([]*client, 0)
|
||||
for range d.Peers {
|
||||
client := <-c
|
||||
if client != nil {
|
||||
clients = append(clients, client)
|
||||
}
|
||||
}
|
||||
return clients
|
||||
}
|
||||
|
||||
func (s *swarm) selectClient(index int) (*client, error) {
|
||||
for _, c := range s.clients {
|
||||
if !c.engaged && !c.choked && c.hasPiece(index) {
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
for _, c := range s.clients {
|
||||
if !c.engaged && c.hasPiece(index) {
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Could not find client for piece %d", index)
|
||||
}
|
||||
|
||||
func calculateBoundsForPiece(index, numPieces, length int) (begin int, end int) {
|
||||
pieceLength := length / numPieces
|
||||
begin = index * pieceLength
|
||||
end = begin + pieceLength
|
||||
return begin, end
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func checkIntegrity(pw *pieceWork, buf []byte) error {
|
||||
@@ -111,13 +48,11 @@ func checkIntegrity(pw *pieceWork, buf []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func downloadPiece(c *client, pw *pieceWork, pieceLength int) ([]byte, error) {
|
||||
buf := make([]byte, pieceLength)
|
||||
c.unchoke()
|
||||
c.interested()
|
||||
func attemptDownloadPiece(c *client, pw *pieceWork) ([]byte, error) {
|
||||
buf := make([]byte, pw.length)
|
||||
downloaded := 0
|
||||
requested := 0
|
||||
for downloaded < pieceLength {
|
||||
for downloaded < len(buf) {
|
||||
for c.hasNext() {
|
||||
msg, err := c.read() // this call blocks
|
||||
if err != nil {
|
||||
@@ -127,7 +62,7 @@ func downloadPiece(c *client, pw *pieceWork, pieceLength int) ([]byte, error) {
|
||||
continue
|
||||
}
|
||||
if msg.ID != message.MsgPiece {
|
||||
fmt.Println(msg)
|
||||
log.Println(msg)
|
||||
}
|
||||
switch msg.ID {
|
||||
case message.MsgUnchoke:
|
||||
@@ -141,7 +76,6 @@ func downloadPiece(c *client, pw *pieceWork, pieceLength int) ([]byte, error) {
|
||||
}
|
||||
c.bitfield.SetPiece(index)
|
||||
case message.MsgPiece:
|
||||
// fmt.Println(" PIECE")
|
||||
n, err := message.ParsePiece(pw.index, buf, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -150,14 +84,13 @@ func downloadPiece(c *client, pw *pieceWork, pieceLength int) ([]byte, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if !c.choked && requested < pieceLength && requested-downloaded <= queueSize+1 {
|
||||
for i := 0; i < queueSize; i++ {
|
||||
if !c.choked && requested < len(buf) && requested-downloaded <= maxUnfulfilled+1 {
|
||||
for i := 0; i < maxUnfulfilled; i++ {
|
||||
blockSize := maxBlockSize
|
||||
if pieceLength-requested < blockSize {
|
||||
if len(buf)-requested < blockSize {
|
||||
// Last block might be shorter than the typical block
|
||||
blockSize = pieceLength - requested
|
||||
blockSize = len(buf) - requested
|
||||
}
|
||||
// fmt.Println("Request")
|
||||
c.request(pw.index, requested, blockSize)
|
||||
requested += blockSize
|
||||
}
|
||||
@@ -171,7 +104,7 @@ func downloadPiece(c *client, pw *pieceWork, pieceLength int) ([]byte, error) {
|
||||
continue
|
||||
}
|
||||
if msg.ID != message.MsgPiece {
|
||||
fmt.Println(msg)
|
||||
log.Println(msg)
|
||||
}
|
||||
switch msg.ID {
|
||||
case message.MsgUnchoke:
|
||||
@@ -185,7 +118,6 @@ func downloadPiece(c *client, pw *pieceWork, pieceLength int) ([]byte, error) {
|
||||
}
|
||||
c.bitfield.SetPiece(index)
|
||||
case message.MsgPiece:
|
||||
// fmt.Println(" PIECE")
|
||||
n, err := message.ParsePiece(pw.index, buf, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -193,95 +125,81 @@ func downloadPiece(c *client, pw *pieceWork, pieceLength int) ([]byte, error) {
|
||||
downloaded += n
|
||||
}
|
||||
}
|
||||
|
||||
c.have(pw.index)
|
||||
c.notInterested()
|
||||
|
||||
err := checkIntegrity(pw, buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func (s *swarm) removeClient(c *client) {
|
||||
if len(s.clients) == 1 {
|
||||
panic("Removed last client")
|
||||
func (t *Torrent) downloadWorker(peer Peer, workQueue chan *pieceWork, results chan *pieceResult) {
|
||||
c, err := newClient(peer, t.PeerID, t.InfoHash)
|
||||
if err != nil {
|
||||
log.Printf("Peer %s unresponsive. Disconnecting\n", peer.IP)
|
||||
return
|
||||
}
|
||||
log.Printf("Removing client. %d clients remaining\n", len(s.clients))
|
||||
s.mux.Lock()
|
||||
c.conn.Close()
|
||||
var i int
|
||||
for i = 0; i < len(s.clients); i++ {
|
||||
if s.clients[i] == c {
|
||||
break
|
||||
}
|
||||
}
|
||||
s.clients = append(s.clients[:i], s.clients[i+1:]...)
|
||||
s.mux.Unlock()
|
||||
}
|
||||
defer c.conn.Close()
|
||||
|
||||
func (s *swarm) worker(d *Download, wg *sync.WaitGroup) {
|
||||
for pw := range s.queue {
|
||||
s.mux.Lock()
|
||||
c, err := s.selectClient(pw.index)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
// Re-enqueue the piece to try again
|
||||
s.queue <- pw
|
||||
s.mux.Unlock()
|
||||
c.unchoke()
|
||||
c.interested()
|
||||
|
||||
for pw := range workQueue {
|
||||
if !c.hasPiece(pw.index) {
|
||||
workQueue <- pw // Put piece back on the queue
|
||||
continue
|
||||
}
|
||||
c.engaged = true
|
||||
s.mux.Unlock()
|
||||
|
||||
begin, end := calculateBoundsForPiece(pw.index, len(d.PieceHashes), d.Length)
|
||||
pieceLength := end - begin
|
||||
pieceBuf, err := downloadPiece(c, pw, pieceLength)
|
||||
// Download the piece
|
||||
buf, err := attemptDownloadPiece(c, pw)
|
||||
if err != nil {
|
||||
// Re-enqueue the piece to try again
|
||||
log.Println(err)
|
||||
s.removeClient(c)
|
||||
s.queue <- pw
|
||||
} else {
|
||||
// Copy into buffer should not overlap with other workers
|
||||
copy(s.buf[begin:end], pieceBuf)
|
||||
|
||||
s.mux.Lock()
|
||||
s.piecesDone++
|
||||
log.Printf("Downloaded piece %d (%d/%d) %0.2f%%\n", pw.index, s.piecesDone, len(d.PieceHashes), float64(s.piecesDone)/float64(len(d.PieceHashes))*100)
|
||||
s.mux.Unlock()
|
||||
|
||||
if s.piecesDone == len(d.PieceHashes) {
|
||||
close(s.queue)
|
||||
}
|
||||
log.Println("Exiting", err)
|
||||
workQueue <- pw // Put piece back on the queue
|
||||
return
|
||||
}
|
||||
|
||||
s.mux.Lock()
|
||||
c.engaged = false
|
||||
s.mux.Unlock()
|
||||
err = checkIntegrity(pw, buf)
|
||||
if err != nil {
|
||||
log.Printf("Piece #%d failed integrity check\n", pw.index)
|
||||
workQueue <- pw // Put piece back on the queue
|
||||
continue
|
||||
}
|
||||
|
||||
results <- &pieceResult{pw.index, buf}
|
||||
c.have(pw.index)
|
||||
}
|
||||
wg.Done()
|
||||
}
|
||||
|
||||
func (d *Download) processQueue(clients []*client, queue chan *pieceWork) []byte {
|
||||
s := swarm{
|
||||
clients: clients,
|
||||
queue: queue,
|
||||
buf: make([]byte, d.Length),
|
||||
mux: sync.Mutex{},
|
||||
}
|
||||
|
||||
numWorkers := (len(s.clients) + 1) / 2
|
||||
log.Printf("Spawning %d workers\n", numWorkers)
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(numWorkers)
|
||||
|
||||
for i := 0; i < numWorkers; i++ {
|
||||
go s.worker(d, &wg)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
return s.buf
|
||||
func calculateBoundsForPiece(index, numPieces, length int) (begin int, end int) {
|
||||
pieceLength := length / numPieces
|
||||
begin = index * pieceLength
|
||||
end = begin + pieceLength
|
||||
return begin, end
|
||||
}
|
||||
|
||||
// Download downloads the torrent
|
||||
func (t *Torrent) Download() ([]byte, error) {
|
||||
// Init queues for workers to retrieve work and send results
|
||||
workQueue := make(chan *pieceWork, len(t.PieceHashes))
|
||||
results := make(chan *pieceResult, len(t.PieceHashes))
|
||||
for index, hash := range t.PieceHashes {
|
||||
length := t.Length / len(t.PieceHashes)
|
||||
workQueue <- &pieceWork{index, hash, length}
|
||||
}
|
||||
|
||||
// Start workers
|
||||
for _, peer := range t.Peers {
|
||||
go t.downloadWorker(peer, workQueue, results)
|
||||
}
|
||||
|
||||
// Collect results into a buffer until full
|
||||
buf := make([]byte, t.Length)
|
||||
donePieces := 0
|
||||
for donePieces < len(t.PieceHashes) {
|
||||
res := <-results
|
||||
begin, end := calculateBoundsForPiece(res.index, len(t.PieceHashes), t.Length)
|
||||
copy(buf[begin:end], res.buf)
|
||||
donePieces++
|
||||
|
||||
percent := float64(donePieces) / float64(len(t.PieceHashes)) * 100
|
||||
log.Printf("(%0.2f%%) Downloaded piece #%d with %d goroutines\n", percent, res.index, runtime.NumGoroutine())
|
||||
}
|
||||
close(workQueue)
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user