diff --git a/handshake/handshake.go b/handshake/handshake.go index e8a0f6d..c028348 100644 --- a/handshake/handshake.go +++ b/handshake/handshake.go @@ -1,6 +1,7 @@ package handshake import ( + "bufio" "errors" "io" ) @@ -35,7 +36,7 @@ func (h *Handshake) Serialize() []byte { } // Read parses a message from a stream. Returns `nil` on keep-alive message -func Read(r io.Reader) (*Handshake, error) { +func Read(r *bufio.Reader) (*Handshake, error) { lengthBuf := make([]byte, 1) _, err := io.ReadFull(r, lengthBuf) if err != nil { diff --git a/handshake/handshake_test.go b/handshake/handshake_test.go index 054f055..e2bc237 100644 --- a/handshake/handshake_test.go +++ b/handshake/handshake_test.go @@ -1,6 +1,7 @@ package handshake import ( + "bufio" "bytes" "testing" @@ -73,7 +74,7 @@ func TestRead(t *testing.T) { } for _, test := range tests { - reader := bytes.NewReader(test.input) + reader := bufio.NewReader(bytes.NewReader(test.input)) m, err := Read(reader) if test.fails { assert.NotNil(t, err) diff --git a/main.go b/main.go index 0bc29bf..f3fc3a5 100644 --- a/main.go +++ b/main.go @@ -8,17 +8,30 @@ import ( ) func main() { - file, err := os.Open(os.Args[1]) - if err != nil { - log.Fatal(err) - } - defer file.Close() + inPath := os.Args[1] + outPath := os.Args[2] - t, err := torrent.Open(file) + inFile, err := os.Open(inPath) if err != nil { log.Fatal(err) } - err = t.Download() + defer inFile.Close() + + t, err := torrent.Open(inFile) + if err != nil { + log.Fatal(err) + } + buf, err := t.Download() + if err != nil { + log.Fatal(err) + } + + outFile, err := os.Create(outPath) + if err != nil { + log.Fatal(err) + } + defer outFile.Close() + _, err = outFile.Write(buf) if err != nil { log.Fatal(err) } diff --git a/message/message.go b/message/message.go index ffeccc4..608b6e3 100644 --- a/message/message.go +++ b/message/message.go @@ -1,6 +1,7 @@ package message import ( + "bufio" "encoding/binary" "errors" "fmt" @@ -84,7 +85,7 @@ func (m *Message) Serialize() []byte { } // Read parses a message from a stream. Returns `nil` on keep-alive message -func Read(r io.Reader) (*Message, error) { +func Read(r *bufio.Reader) (*Message, error) { lengthBuf := make([]byte, 4) _, err := io.ReadFull(r, lengthBuf) if err != nil { @@ -115,7 +116,6 @@ func (m *Message) String() string { if m == nil { return "KeepAlive" } - switch m.ID { case MsgChoke: return "Choke" diff --git a/message/message_test.go b/message/message_test.go index 4760a16..40bde91 100644 --- a/message/message_test.go +++ b/message/message_test.go @@ -1,6 +1,7 @@ package message import ( + "bufio" "bytes" "testing" @@ -179,7 +180,7 @@ func TestRead(t *testing.T) { } for _, test := range tests { - reader := bytes.NewReader(test.input) + reader := bufio.NewReader(bytes.NewReader(test.input)) m, err := Read(reader) if test.fails { assert.NotNil(t, err) diff --git a/p2p/client.go b/p2p/client.go index 4a8cd74..4a9d1b8 100644 --- a/p2p/client.go +++ b/p2p/client.go @@ -1,10 +1,11 @@ package p2p import ( + "bufio" + "encoding/binary" "fmt" "net" "strconv" - "sync" "time" "github.com/veggiedefender/torrent-client/message" @@ -13,13 +14,17 @@ import ( ) type client struct { + peer Peer + infoHash [20]byte + peerID [20]byte conn net.Conn + reader *bufio.Reader bitfield message.Bitfield - Choked bool - Mux sync.Mutex + choked bool + engaged bool } -func completeHandshake(conn net.Conn, infohash, peerID [20]byte) (*handshake.Handshake, error) { +func completeHandshake(conn net.Conn, r *bufio.Reader, infohash, peerID [20]byte) (*handshake.Handshake, error) { conn.SetDeadline(time.Now().Local().Add(3 * time.Second)) defer conn.SetDeadline(time.Time{}) // Disable the deadline @@ -29,18 +34,18 @@ func completeHandshake(conn net.Conn, infohash, peerID [20]byte) (*handshake.Han return nil, err } - res, err := handshake.Read(conn) + res, err := handshake.Read(r) if err != nil { return nil, err } return res, nil } -func recvBitfield(conn net.Conn) (message.Bitfield, error) { +func recvBitfield(conn net.Conn, r *bufio.Reader) (message.Bitfield, error) { conn.SetDeadline(time.Now().Local().Add(5 * time.Second)) defer conn.SetDeadline(time.Time{}) // Disable the deadline - msg, err := message.Read(conn) + msg, err := message.Read(r) if err != nil { return nil, err } @@ -58,22 +63,72 @@ func newClient(peer Peer, peerID, infoHash [20]byte) (*client, error) { if err != nil { return nil, err } - _, err = completeHandshake(conn, infoHash, peerID) + reader := bufio.NewReader(conn) + + _, err = completeHandshake(conn, reader, infoHash, peerID) if err != nil { return nil, err } - bf, err := recvBitfield(conn) + + bf, err := recvBitfield(conn, reader) if err != nil { return nil, err } + return &client{ + peer: peer, + infoHash: infoHash, + peerID: peerID, conn: conn, + reader: reader, bitfield: bf, - Mux: sync.Mutex{}, - Choked: true, + choked: true, + engaged: false, }, nil } func (c *client) hasPiece(index int) bool { return c.bitfield.HasPiece(index) } + +func (c *client) hasNext() bool { + fmt.Println(c.reader.Buffered() > 0) + return c.reader.Buffered() > 0 +} + +func (c *client) read() (*message.Message, error) { + msg, err := message.Read(c.reader) + return msg, err +} + +func (c *client) request(index, begin, length int) error { + req := message.FormatRequest(index, begin, length) + _, err := c.conn.Write(req.Serialize()) + return err +} + +func (c *client) interested() error { + msg := message.Message{ID: message.MsgInterested} + _, err := c.conn.Write(msg.Serialize()) + return err +} + +func (c *client) notInterested() error { + msg := message.Message{ID: message.MsgNotInterested} + _, err := c.conn.Write(msg.Serialize()) + return err +} + +func (c *client) unchoke() error { + msg := message.Message{ID: message.MsgUnchoke} + _, err := c.conn.Write(msg.Serialize()) + return err +} + +func (c *client) have(index int) error { + pl := make([]byte, 4) + binary.BigEndian.PutUint32(pl, uint32(index)) + msg := message.Message{ID: message.MsgHave, Payload: pl} + _, err := c.conn.Write(msg.Serialize()) + return err +} diff --git a/p2p/p2p.go b/p2p/p2p.go index e1da35e..f041194 100644 --- a/p2p/p2p.go +++ b/p2p/p2p.go @@ -1,11 +1,18 @@ package p2p import ( + "bytes" + "crypto/sha1" "fmt" + "log" "net" "sync" + + "github.com/veggiedefender/torrent-client/message" ) +const maxBlockSize = 32768 + // Peer encodes connection information for a peer type Peer struct { IP net.IP @@ -27,25 +34,26 @@ type pieceWork struct { } type swarm struct { - clients []*client - queue chan *pieceWork - mux sync.Mutex + clients []*client + queue chan *pieceWork + buf []byte + piecesDone int + mux sync.Mutex } // Download downloads a torrent -func (d *Download) Download() error { +func (d *Download) Download() ([]byte, error) { clients := d.initClients() if len(clients) == 0 { - return fmt.Errorf("Could not connect to any of %d clients", len(d.Peers)) + return nil, fmt.Errorf("Could not connect to any of %d clients", len(d.Peers)) } + log.Printf("Connected to %d clients\n", len(clients)) queue := make(chan *pieceWork, len(d.PieceHashes)) for index, hash := range d.PieceHashes { queue <- &pieceWork{index, hash} } - processQueue(clients, queue) - - return nil + return d.processQueue(clients, queue), nil } func (d *Download) initClients() []*client { @@ -62,6 +70,7 @@ func (d *Download) initClients() []*client { }(p) } + // Gather clients into a slice clients := make([]*client, 0) for range d.Peers { client := <-c @@ -72,15 +81,177 @@ func (d *Download) initClients() []*client { return clients } -func (s *swarm) selectClient(index int) *client { - return s.clients[0] +func (s *swarm) selectClient(index int) (*client, error) { + for _, c := range s.clients { + if !c.engaged && !c.choked && c.hasPiece(index) { + return c, nil + } + } + for _, c := range s.clients { + if !c.engaged && c.hasPiece(index) { + return c, nil + } + } + return nil, fmt.Errorf("Could not find client for piece %d", index) } -func processQueue(clients []*client, queue chan *pieceWork) { - s := swarm{clients, queue, sync.Mutex{}} - for pw := range s.queue { - client := s.selectClient(pw.index) - fmt.Println(client.conn.RemoteAddr()) - break +func calculateBoundsForPiece(index, numPieces, length int) (int, int) { + pieceLength := length / numPieces + begin := index * pieceLength + end := begin + pieceLength + if end > length-1 { + end = length - 1 } + return begin, end +} + +func checkIntegrity(pw *pieceWork, buf []byte) error { + hash := sha1.Sum(buf) + if !bytes.Equal(hash[:], pw.hash[:]) { + return fmt.Errorf("Index %d failed integrity check", pw.index) + } + return nil +} + +func downloadPiece(c *client, pw *pieceWork, pieceLength int) ([]byte, error) { + buf := make([]byte, pieceLength) + c.unchoke() + c.interested() + offset := 0 + for c.hasNext() { + msg, err := c.read() // this call blocks + if err != nil { + return nil, err + } + if msg == nil { // keep-alive + continue + } + fmt.Println("CATCHING UP ON", msg) + switch msg.ID { + case message.MsgUnchoke: + c.choked = false + case message.MsgChoke: + c.choked = true + } + } + + for offset < pieceLength { + if !c.choked { + blockSize := maxBlockSize + if pieceLength-offset < blockSize { + // Last block might be shorter than the typical block + blockSize = pieceLength - offset + } + c.request(pw.index, offset, blockSize) + } + + msg, err := c.read() // this call blocks + if err != nil { + return nil, err + } + if msg == nil { // keep-alive + continue + } + if msg.ID != message.MsgPiece { + fmt.Println(msg) + } + switch msg.ID { + case message.MsgUnchoke: + c.choked = false + case message.MsgChoke: + c.choked = true + case message.MsgPiece: + n, err := message.ParsePiece(pw.index, buf, msg) + if err != nil { + return nil, err + } + offset += n + } + } + + c.have(pw.index) + c.notInterested() + + err := checkIntegrity(pw, buf) + if err != nil { + return nil, err + } + + return buf, nil +} + +func (s *swarm) removeClient(c *client) { + log.Printf("Removing client. %d clients remaining\n", len(s.clients)) + s.mux.Lock() + var i int + for i = 0; i < len(s.clients); i++ { + if s.clients[i] == c { + break + } + } + s.clients = append(s.clients[:i], s.clients[i+1:]...) + s.mux.Unlock() +} + +func (s *swarm) worker(d *Download, wg *sync.WaitGroup) { + for pw := range s.queue { + s.mux.Lock() + c, err := s.selectClient(pw.index) + if err != nil { + // Re-enqueue the piece to try again + s.queue <- pw + s.mux.Unlock() + continue + } + c.engaged = true + s.mux.Unlock() + + begin, end := calculateBoundsForPiece(pw.index, len(d.PieceHashes), d.Length) + pieceLength := end - begin + pieceBuf, err := downloadPiece(c, pw, pieceLength) + if err != nil { + // Re-enqueue the piece to try again + log.Println(err) + s.removeClient(c) + s.queue <- pw + } else { + // Copy into buffer should not overlap with other workers + copy(s.buf[begin:end], pieceBuf) + + s.mux.Lock() + s.piecesDone++ + log.Printf("Downloaded piece %d (%d/%d) %0.2f%%\n", pw.index, s.piecesDone, len(d.PieceHashes), float64(s.piecesDone)/float64(len(d.PieceHashes))) + s.mux.Unlock() + + if s.piecesDone == len(d.PieceHashes) { + close(s.queue) + } + } + + s.mux.Lock() + c.engaged = false + s.mux.Unlock() + } + wg.Done() +} + +func (d *Download) processQueue(clients []*client, queue chan *pieceWork) []byte { + s := swarm{ + clients: clients, + queue: queue, + buf: make([]byte, d.Length), + mux: sync.Mutex{}, + } + + numWorkers := len(s.clients) / 2 + log.Printf("Spawning %d workers\n", numWorkers) + wg := sync.WaitGroup{} + wg.Add(numWorkers) + + for i := 0; i < numWorkers; i++ { + go s.worker(d, &wg) + } + + wg.Wait() + return s.buf } diff --git a/torrent/torrent.go b/torrent/torrent.go index 12ed745..8c5d3de 100644 --- a/torrent/torrent.go +++ b/torrent/torrent.go @@ -6,7 +6,6 @@ import ( "crypto/sha1" "fmt" "io" - "net" "github.com/jackpal/bencode-go" "github.com/veggiedefender/torrent-client/p2p" @@ -38,15 +37,15 @@ type bencodeTorrent struct { } // Download downloads a torrent -func (t *Torrent) Download() error { +func (t *Torrent) Download() ([]byte, error) { var peerID [20]byte _, err := rand.Read(peerID[:]) if err != nil { - return err + return nil, err } - // peers, err := t.getPeers(peerID, Port) - peers := []p2p.Peer{{IP: net.IP{127, 0, 0, 1}, Port: 51413}} + peers, err := t.getPeers(peerID, Port) + // peers := []p2p.Peer{{IP: net.IP{127, 0, 0, 1}, Port: 51413}} downloader := p2p.Download{ Peers: peers, PeerID: peerID, @@ -54,8 +53,11 @@ func (t *Torrent) Download() error { PieceHashes: t.PieceHashes, Length: t.Length, } - err = downloader.Download() - return err + buf, err := downloader.Download() + if err != nil { + return nil, err + } + return buf, nil } // Open parses a torrent file