1
0
mirror of https://github.com/veggiedefender/torrent-client.git synced 2025-11-06 09:29:16 +02:00

Refactor p2p

This commit is contained in:
Jesse
2019-12-29 14:02:50 -05:00
parent d7d46ab7ef
commit 9d7ddbc62c
9 changed files with 123 additions and 198 deletions

View File

@@ -4,7 +4,7 @@ import (
"log"
"os"
"github.com/veggiedefender/torrent-client/torrent"
"github.com/veggiedefender/torrent-client/torrentfile"
)
func main() {
@@ -17,7 +17,7 @@ func main() {
}
defer inFile.Close()
t, err := torrent.Open(inFile)
t, err := torrentfile.Open(inFile)
if err != nil {
log.Fatal(err)
}

View File

@@ -128,7 +128,7 @@ func Read(r *bufio.Reader) (*Message, error) {
return &m, nil
}
func (m *Message) String() string {
func (m *Message) name() string {
if m == nil {
return "KeepAlive"
}
@@ -158,6 +158,13 @@ func (m *Message) String() string {
}
}
func (m *Message) String() string {
if m == nil {
return m.name()
}
return fmt.Sprintf("%s [%d]", m.name(), len(m.Payload))
}
// HasPiece tells if a bitfield has a particular index set
func (bf Bitfield) HasPiece(index int) bool {
byteIndex := index / 8

View File

@@ -245,17 +245,17 @@ func TestString(t *testing.T) {
output string
}{
{nil, "KeepAlive"},
{&Message{MsgChoke, []byte{1, 2, 3}}, "Choke"},
{&Message{MsgUnchoke, []byte{1, 2, 3}}, "Unchoke"},
{&Message{MsgInterested, []byte{1, 2, 3}}, "Interested"},
{&Message{MsgNotInterested, []byte{1, 2, 3}}, "NotInterested"},
{&Message{MsgHave, []byte{1, 2, 3}}, "Have"},
{&Message{MsgBitfield, []byte{1, 2, 3}}, "Bitfield"},
{&Message{MsgRequest, []byte{1, 2, 3}}, "Request"},
{&Message{MsgPiece, []byte{1, 2, 3}}, "Piece"},
{&Message{MsgCancel, []byte{1, 2, 3}}, "Cancel"},
{&Message{MsgPort, []byte{1, 2, 3}}, "Port"},
{&Message{99, []byte{1, 2, 3}}, "Unknown#99"},
{&Message{MsgChoke, []byte{1, 2, 3}}, "Choke [3]"},
{&Message{MsgUnchoke, []byte{1, 2, 3}}, "Unchoke [3]"},
{&Message{MsgInterested, []byte{1, 2, 3}}, "Interested [3]"},
{&Message{MsgNotInterested, []byte{1, 2, 3}}, "NotInterested [3]"},
{&Message{MsgHave, []byte{1, 2, 3}}, "Have [3]"},
{&Message{MsgBitfield, []byte{1, 2, 3}}, "Bitfield [3]"},
{&Message{MsgRequest, []byte{1, 2, 3}}, "Request [3]"},
{&Message{MsgPiece, []byte{1, 2, 3}}, "Piece [3]"},
{&Message{MsgCancel, []byte{1, 2, 3}}, "Cancel [3]"},
{&Message{MsgPort, []byte{1, 2, 3}}, "Port [3]"},
{&Message{99, []byte{1, 2, 3}}, "Unknown#99 [3]"},
}
for _, test := range tests {

View File

@@ -20,7 +20,6 @@ type client struct {
reader *bufio.Reader
bitfield message.Bitfield
choked bool
engaged bool
}
func completeHandshake(conn net.Conn, r *bufio.Reader, infohash, peerID [20]byte) (*handshake.Handshake, error) {
@@ -66,11 +65,13 @@ func newClient(peer Peer, peerID, infoHash [20]byte) (*client, error) {
_, err = completeHandshake(conn, reader, infoHash, peerID)
if err != nil {
conn.Close()
return nil, err
}
bf, err := recvBitfield(conn, reader)
if err != nil {
conn.Close()
return nil, err
}
@@ -82,7 +83,6 @@ func newClient(peer Peer, peerID, infoHash [20]byte) (*client, error) {
reader: reader,
bitfield: bf,
choked: true,
engaged: false,
}, nil
}

View File

@@ -6,13 +6,13 @@ import (
"fmt"
"log"
"net"
"sync"
"runtime"
"github.com/veggiedefender/torrent-client/message"
)
const maxBlockSize = 16384
const queueSize = 5
const maxUnfulfilled = 5
// Peer encodes connection information for a peer
type Peer struct {
@@ -20,8 +20,8 @@ type Peer struct {
Port uint16
}
// Download holds data required to download a torrent from a list of peers
type Download struct {
// Torrent holds data required to download a torrent from a list of peers
type Torrent struct {
Peers []Peer
PeerID [20]byte
InfoHash [20]byte
@@ -30,77 +30,14 @@ type Download struct {
}
type pieceWork struct {
index int
hash [20]byte
length int
}
type pieceResult struct {
index int
hash [20]byte
}
type swarm struct {
clients []*client
queue chan *pieceWork
buf []byte
piecesDone int
mux sync.Mutex
}
// Download downloads a torrent
func (d *Download) Download() ([]byte, error) {
clients := d.initClients()
if len(clients) == 0 {
return nil, fmt.Errorf("Could not connect to any of %d clients", len(d.Peers))
}
log.Printf("Connected to %d clients\n", len(clients))
queue := make(chan *pieceWork, len(d.PieceHashes))
for index, hash := range d.PieceHashes {
queue <- &pieceWork{index, hash}
}
return d.processQueue(clients, queue), nil
}
func (d *Download) initClients() []*client {
// Create clients in parallel
c := make(chan *client)
for _, p := range d.Peers {
go func(p Peer) {
client, err := newClient(p, d.PeerID, d.InfoHash)
if err != nil {
c <- nil
} else {
c <- client
}
}(p)
}
// Gather clients into a slice
clients := make([]*client, 0)
for range d.Peers {
client := <-c
if client != nil {
clients = append(clients, client)
}
}
return clients
}
func (s *swarm) selectClient(index int) (*client, error) {
for _, c := range s.clients {
if !c.engaged && !c.choked && c.hasPiece(index) {
return c, nil
}
}
for _, c := range s.clients {
if !c.engaged && c.hasPiece(index) {
return c, nil
}
}
return nil, fmt.Errorf("Could not find client for piece %d", index)
}
func calculateBoundsForPiece(index, numPieces, length int) (begin int, end int) {
pieceLength := length / numPieces
begin = index * pieceLength
end = begin + pieceLength
return begin, end
buf []byte
}
func checkIntegrity(pw *pieceWork, buf []byte) error {
@@ -111,13 +48,11 @@ func checkIntegrity(pw *pieceWork, buf []byte) error {
return nil
}
func downloadPiece(c *client, pw *pieceWork, pieceLength int) ([]byte, error) {
buf := make([]byte, pieceLength)
c.unchoke()
c.interested()
func attemptDownloadPiece(c *client, pw *pieceWork) ([]byte, error) {
buf := make([]byte, pw.length)
downloaded := 0
requested := 0
for downloaded < pieceLength {
for downloaded < len(buf) {
for c.hasNext() {
msg, err := c.read() // this call blocks
if err != nil {
@@ -127,7 +62,7 @@ func downloadPiece(c *client, pw *pieceWork, pieceLength int) ([]byte, error) {
continue
}
if msg.ID != message.MsgPiece {
fmt.Println(msg)
log.Println(msg)
}
switch msg.ID {
case message.MsgUnchoke:
@@ -141,7 +76,6 @@ func downloadPiece(c *client, pw *pieceWork, pieceLength int) ([]byte, error) {
}
c.bitfield.SetPiece(index)
case message.MsgPiece:
// fmt.Println(" PIECE")
n, err := message.ParsePiece(pw.index, buf, msg)
if err != nil {
return nil, err
@@ -150,14 +84,13 @@ func downloadPiece(c *client, pw *pieceWork, pieceLength int) ([]byte, error) {
}
}
if !c.choked && requested < pieceLength && requested-downloaded <= queueSize+1 {
for i := 0; i < queueSize; i++ {
if !c.choked && requested < len(buf) && requested-downloaded <= maxUnfulfilled+1 {
for i := 0; i < maxUnfulfilled; i++ {
blockSize := maxBlockSize
if pieceLength-requested < blockSize {
if len(buf)-requested < blockSize {
// Last block might be shorter than the typical block
blockSize = pieceLength - requested
blockSize = len(buf) - requested
}
// fmt.Println("Request")
c.request(pw.index, requested, blockSize)
requested += blockSize
}
@@ -171,7 +104,7 @@ func downloadPiece(c *client, pw *pieceWork, pieceLength int) ([]byte, error) {
continue
}
if msg.ID != message.MsgPiece {
fmt.Println(msg)
log.Println(msg)
}
switch msg.ID {
case message.MsgUnchoke:
@@ -185,7 +118,6 @@ func downloadPiece(c *client, pw *pieceWork, pieceLength int) ([]byte, error) {
}
c.bitfield.SetPiece(index)
case message.MsgPiece:
// fmt.Println(" PIECE")
n, err := message.ParsePiece(pw.index, buf, msg)
if err != nil {
return nil, err
@@ -193,95 +125,81 @@ func downloadPiece(c *client, pw *pieceWork, pieceLength int) ([]byte, error) {
downloaded += n
}
}
c.have(pw.index)
c.notInterested()
err := checkIntegrity(pw, buf)
if err != nil {
return nil, err
}
return buf, nil
}
func (s *swarm) removeClient(c *client) {
if len(s.clients) == 1 {
panic("Removed last client")
func (t *Torrent) downloadWorker(peer Peer, workQueue chan *pieceWork, results chan *pieceResult) {
c, err := newClient(peer, t.PeerID, t.InfoHash)
if err != nil {
log.Printf("Peer %s unresponsive. Disconnecting\n", peer.IP)
return
}
log.Printf("Removing client. %d clients remaining\n", len(s.clients))
s.mux.Lock()
c.conn.Close()
var i int
for i = 0; i < len(s.clients); i++ {
if s.clients[i] == c {
break
}
}
s.clients = append(s.clients[:i], s.clients[i+1:]...)
s.mux.Unlock()
}
defer c.conn.Close()
func (s *swarm) worker(d *Download, wg *sync.WaitGroup) {
for pw := range s.queue {
s.mux.Lock()
c, err := s.selectClient(pw.index)
if err != nil {
fmt.Println(err)
// Re-enqueue the piece to try again
s.queue <- pw
s.mux.Unlock()
c.unchoke()
c.interested()
for pw := range workQueue {
if !c.hasPiece(pw.index) {
workQueue <- pw // Put piece back on the queue
continue
}
c.engaged = true
s.mux.Unlock()
begin, end := calculateBoundsForPiece(pw.index, len(d.PieceHashes), d.Length)
pieceLength := end - begin
pieceBuf, err := downloadPiece(c, pw, pieceLength)
// Download the piece
buf, err := attemptDownloadPiece(c, pw)
if err != nil {
// Re-enqueue the piece to try again
log.Println(err)
s.removeClient(c)
s.queue <- pw
} else {
// Copy into buffer should not overlap with other workers
copy(s.buf[begin:end], pieceBuf)
s.mux.Lock()
s.piecesDone++
log.Printf("Downloaded piece %d (%d/%d) %0.2f%%\n", pw.index, s.piecesDone, len(d.PieceHashes), float64(s.piecesDone)/float64(len(d.PieceHashes))*100)
s.mux.Unlock()
if s.piecesDone == len(d.PieceHashes) {
close(s.queue)
}
log.Println("Exiting", err)
workQueue <- pw // Put piece back on the queue
return
}
s.mux.Lock()
c.engaged = false
s.mux.Unlock()
err = checkIntegrity(pw, buf)
if err != nil {
log.Printf("Piece #%d failed integrity check\n", pw.index)
workQueue <- pw // Put piece back on the queue
continue
}
results <- &pieceResult{pw.index, buf}
c.have(pw.index)
}
wg.Done()
}
func (d *Download) processQueue(clients []*client, queue chan *pieceWork) []byte {
s := swarm{
clients: clients,
queue: queue,
buf: make([]byte, d.Length),
mux: sync.Mutex{},
}
numWorkers := (len(s.clients) + 1) / 2
log.Printf("Spawning %d workers\n", numWorkers)
wg := sync.WaitGroup{}
wg.Add(numWorkers)
for i := 0; i < numWorkers; i++ {
go s.worker(d, &wg)
}
wg.Wait()
return s.buf
func calculateBoundsForPiece(index, numPieces, length int) (begin int, end int) {
pieceLength := length / numPieces
begin = index * pieceLength
end = begin + pieceLength
return begin, end
}
// Download downloads the torrent
func (t *Torrent) Download() ([]byte, error) {
// Init queues for workers to retrieve work and send results
workQueue := make(chan *pieceWork, len(t.PieceHashes))
results := make(chan *pieceResult, len(t.PieceHashes))
for index, hash := range t.PieceHashes {
length := t.Length / len(t.PieceHashes)
workQueue <- &pieceWork{index, hash, length}
}
// Start workers
for _, peer := range t.Peers {
go t.downloadWorker(peer, workQueue, results)
}
// Collect results into a buffer until full
buf := make([]byte, t.Length)
donePieces := 0
for donePieces < len(t.PieceHashes) {
res := <-results
begin, end := calculateBoundsForPiece(res.index, len(t.PieceHashes), t.Length)
copy(buf[begin:end], res.buf)
donePieces++
percent := float64(donePieces) / float64(len(t.PieceHashes)) * 100
log.Printf("(%0.2f%%) Downloaded piece #%d with %d goroutines\n", percent, res.index, runtime.NumGoroutine())
}
close(workQueue)
return buf, nil
}

View File

@@ -1,4 +1,4 @@
package torrent
package torrentfile
import (
"bytes"
@@ -14,8 +14,8 @@ import (
// Port to listen on
const Port uint16 = 6881
// Torrent encodes the metadata from a .torrent file
type Torrent struct {
// TorrentFile encodes the metadata from a .torrent file
type TorrentFile struct {
Announce string
InfoHash [20]byte
PieceHashes [][20]byte
@@ -37,7 +37,7 @@ type bencodeTorrent struct {
}
// Download downloads a torrent
func (t *Torrent) Download() ([]byte, error) {
func (t *TorrentFile) Download() ([]byte, error) {
var peerID [20]byte
_, err := rand.Read(peerID[:])
if err != nil {
@@ -47,14 +47,14 @@ func (t *Torrent) Download() ([]byte, error) {
peers, err := t.getPeers(peerID, Port)
// peers = append(peers, p2p.Peer{IP: net.IP{127, 0, 0, 1}, Port: 51413})
// peers := []p2p.Peer{{IP: net.IP{127, 0, 0, 1}, Port: 51413}}
downloader := p2p.Download{
torrent := p2p.Torrent{
Peers: peers,
PeerID: peerID,
InfoHash: t.InfoHash,
PieceHashes: t.PieceHashes,
Length: t.Length,
}
buf, err := downloader.Download()
buf, err := torrent.Download()
if err != nil {
return nil, err
}
@@ -62,7 +62,7 @@ func (t *Torrent) Download() ([]byte, error) {
}
// Open parses a torrent file
func Open(r io.Reader) (*Torrent, error) {
func Open(r io.Reader) (*TorrentFile, error) {
bto := bencodeTorrent{}
err := bencode.Unmarshal(r, &bto)
if err != nil {
@@ -101,7 +101,7 @@ func (i *bencodeInfo) splitPieceHashes() ([][20]byte, error) {
return hashes, nil
}
func (bto *bencodeTorrent) toTorrent() (*Torrent, error) {
func (bto *bencodeTorrent) toTorrent() (*TorrentFile, error) {
infoHash, err := bto.Info.hash()
if err != nil {
return nil, err
@@ -110,7 +110,7 @@ func (bto *bencodeTorrent) toTorrent() (*Torrent, error) {
if err != nil {
return nil, err
}
t := Torrent{
t := TorrentFile{
Announce: bto.Announce,
InfoHash: infoHash,
PieceHashes: pieceHashes,

View File

@@ -1,4 +1,4 @@
package torrent
package torrentfile
import (
"testing"
@@ -9,7 +9,7 @@ import (
func TestToTorrent(t *testing.T) {
tests := map[string]struct {
input *bencodeTorrent
output *Torrent
output *TorrentFile
fails bool
}{
"correct conversion": {
@@ -22,7 +22,7 @@ func TestToTorrent(t *testing.T) {
Name: "debian-10.2.0-amd64-netinst.iso",
},
},
output: &Torrent{
output: &TorrentFile{
Announce: "http://bttracker.debian.org:6969/announce",
InfoHash: [20]byte{216, 247, 57, 206, 195, 40, 149, 108, 204, 91, 191, 31, 134, 217, 253, 207, 219, 168, 206, 182},
PieceHashes: [][20]byte{

View File

@@ -1,4 +1,4 @@
package torrent
package torrentfile
import (
"encoding/binary"
@@ -33,7 +33,7 @@ func parsePeers(peersBin string) ([]p2p.Peer, error) {
return peers, nil
}
func (t *Torrent) buildTrackerURL(peerID [20]byte, port uint16) (string, error) {
func (t *TorrentFile) buildTrackerURL(peerID [20]byte, port uint16) (string, error) {
base, err := url.Parse(t.Announce)
if err != nil {
return "", err
@@ -51,7 +51,7 @@ func (t *Torrent) buildTrackerURL(peerID [20]byte, port uint16) (string, error)
return base.String(), nil
}
func (t *Torrent) getPeers(peerID [20]byte, port uint16) ([]p2p.Peer, error) {
func (t *TorrentFile) getPeers(peerID [20]byte, port uint16) ([]p2p.Peer, error) {
url, err := t.buildTrackerURL(peerID, port)
if err != nil {
return nil, err

View File

@@ -1,4 +1,4 @@
package torrent
package torrentfile
import (
"net"
@@ -9,7 +9,7 @@ import (
)
func TestBuildTrackerURL(t *testing.T) {
to := Torrent{
to := TorrentFile{
Announce: "http://bttracker.debian.org:6969/announce",
InfoHash: [20]byte{216, 247, 57, 206, 195, 40, 149, 108, 204, 91, 191, 31, 134, 217, 253, 207, 219, 168, 206, 182},
PieceHashes: [][20]byte{