diff --git a/main.go b/main.go index b8b387c..0bc29bf 100644 --- a/main.go +++ b/main.go @@ -18,5 +18,8 @@ func main() { if err != nil { log.Fatal(err) } - t.Download() + err = t.Download() + if err != nil { + log.Fatal(err) + } } diff --git a/message/message.go b/message/message.go index 48ad785..6f7aeda 100644 --- a/message/message.go +++ b/message/message.go @@ -2,6 +2,7 @@ package message import ( "encoding/binary" + "errors" "fmt" "io" ) @@ -28,6 +29,42 @@ type Message struct { Payload []byte } +// FormatRequest formats the ID and payload for a request message +func FormatRequest(index, begin, length int) *Message { + payload := make([]byte, 12) + binary.BigEndian.PutUint32(payload[0:4], uint32(index)) + binary.BigEndian.PutUint32(payload[4:8], uint32(begin)) + binary.BigEndian.PutUint32(payload[8:12], uint32(length)) + return &Message{ + ID: MsgRequest, + Payload: payload, + } +} + +// ParsePiece parses a piece message and copies its payload into a buffer +func ParsePiece(index int, buf []byte, msg *Message) (int, error) { + if msg.ID != MsgPiece { + return 0, fmt.Errorf("Expected ID %d, got ID %d", MsgPiece, msg.ID) + } + if len(msg.Payload) < 8 { + return 0, errors.New("Payload too short") + } + parsedIndex := int(binary.BigEndian.Uint32(msg.Payload[0:4])) + if parsedIndex != index { + return 0, fmt.Errorf("Expected index %d, got %d", index, parsedIndex) + } + begin := int(binary.BigEndian.Uint32(msg.Payload[4:8])) + if begin >= len(buf) { + return 0, fmt.Errorf("Begin offset too high. %d >= %d", begin, len(buf)) + } + data := msg.Payload[8:] + if begin+len(data) > len(buf) { + return 0, fmt.Errorf("Data too long [%d] for offset %d with length %d", len(data), begin, len(buf)) + } + copy(buf[begin:], data) + return len(data), nil +} + // Serialize serializes a message into a buffer of the form // // Interprets `nil` as a keep-alive message diff --git a/message/message_test.go b/message/message_test.go index da6a675..7da5c8f 100644 --- a/message/message_test.go +++ b/message/message_test.go @@ -7,6 +7,128 @@ import ( "github.com/stretchr/testify/assert" ) +func TestFormatRequest(t *testing.T) { + msg := FormatRequest(4, 567, 4321) + expected := &Message{ + ID: MsgRequest, + Payload: []byte{ + 0x00, 0x00, 0x00, 0x04, // Index + 0x00, 0x00, 0x02, 0x37, // Begin + 0x00, 0x00, 0x10, 0xe1, // Length + }, + } + assert.Equal(t, expected, msg) +} + +func TestParsePiece(t *testing.T) { + tests := map[string]struct { + inputIndex int + inputBuf []byte + inputMsg *Message + outputN int + outputBuf []byte + fails bool + }{ + "parse valid piece": { + inputIndex: 4, + inputBuf: make([]byte, 10), + inputMsg: &Message{ + ID: MsgPiece, + Payload: []byte{ + 0x00, 0x00, 0x00, 0x04, // Index + 0x00, 0x00, 0x00, 0x02, // Begin + 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, // Block + }, + }, + outputBuf: []byte{0x00, 0x00, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00, 0x00}, + outputN: 6, + fails: false, + }, + "wrong message type": { + inputIndex: 4, + inputBuf: make([]byte, 10), + inputMsg: &Message{ + ID: MsgChoke, + Payload: []byte{}, + }, + outputBuf: make([]byte, 10), + outputN: 0, + fails: true, + }, + "payload too short": { + inputIndex: 4, + inputBuf: make([]byte, 10), + inputMsg: &Message{ + ID: MsgPiece, + Payload: []byte{ + 0x00, 0x00, 0x00, 0x04, // Index + 0x00, 0x00, 0x00, // Malformed offset + }, + }, + outputBuf: make([]byte, 10), + outputN: 0, + fails: true, + }, + "wrong index": { + inputIndex: 4, + inputBuf: make([]byte, 10), + inputMsg: &Message{ + ID: MsgPiece, + Payload: []byte{ + 0x00, 0x00, 0x00, 0x06, // Index is 6, not 4 + 0x00, 0x00, 0x00, 0x02, // Begin + 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, // Block + }, + }, + outputBuf: make([]byte, 10), + outputN: 0, + fails: true, + }, + "offset too high": { + inputIndex: 4, + inputBuf: make([]byte, 10), + inputMsg: &Message{ + ID: MsgPiece, + Payload: []byte{ + 0x00, 0x00, 0x00, 0x04, // Index is 6, not 4 + 0x00, 0x00, 0x00, 0x0c, // Begin is 12 > 10 + 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, // Block + }, + }, + outputBuf: make([]byte, 10), + outputN: 0, + fails: true, + }, + "offset ok but payload too long": { + inputIndex: 4, + inputBuf: make([]byte, 10), + inputMsg: &Message{ + ID: MsgPiece, + Payload: []byte{ + 0x00, 0x00, 0x00, 0x04, // Index is 6, not 4 + 0x00, 0x00, 0x00, 0x02, // Begin is ok + // Block is 10 long but begin=2; too long for input buffer + 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x0a, 0x0b, 0x0c, 0x0d, + }, + }, + outputBuf: make([]byte, 10), + outputN: 0, + fails: true, + }, + } + + for _, test := range tests { + n, err := ParsePiece(test.inputIndex, test.inputBuf, test.inputMsg) + if test.fails { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + } + assert.Equal(t, test.outputBuf, test.inputBuf) + assert.Equal(t, test.outputN, n) + } +} + func TestMessageSerialize(t *testing.T) { tests := map[string]struct { input *Message diff --git a/p2p/p2p.go b/p2p/p2p.go index 12f0cbf..67548a5 100644 --- a/p2p/p2p.go +++ b/p2p/p2p.go @@ -2,7 +2,6 @@ package p2p import ( "crypto/sha1" - "encoding/binary" "encoding/hex" "fmt" "math" @@ -34,10 +33,11 @@ func (d *Downloader) Download() error { if err != nil { return err } - _, err = d.handshake(conn) + h, err := d.handshake(conn) if err != nil { return err } + fmt.Println(h) choked := false pieceSize := d.Length / len(d.PieceHashes) @@ -49,15 +49,23 @@ func (d *Downloader) Download() error { return err } + if msg.ID != message.MsgPiece { + fmt.Println(msg.String()) + } else { + fmt.Println("Received", len(msg.Payload), "bytes") + } + switch msg.ID { case message.MsgChoke: choked = true case message.MsgUnchoke: choked = false case message.MsgPiece: - begin := binary.BigEndian.Uint32(msg.Payload[4:8]) - copy(buf[begin:], msg.Payload[8:]) - i += (len(msg.Payload) - 8) + n, err := message.ParsePiece(0, buf, msg) + if err != nil { + return err + } + i += n } if !choked { @@ -66,15 +74,7 @@ func (d *Downloader) Download() error { remain := pieceSize - i length := int(math.Min(float64(16384), float64(pieceSize))) length = int(math.Min(float64(remain), float64(length))) - payload := make([]byte, 12) - binary.BigEndian.PutUint32(payload[0:4], uint32(index)) - binary.BigEndian.PutUint32(payload[4:8], uint32(begin)) - binary.BigEndian.PutUint32(payload[8:12], uint32(length)) - request := message.Message{ - ID: message.MsgRequest, - Payload: payload, - } - _, err := conn.Write(request.Serialize()) + _, err := conn.Write(message.FormatRequest(index, begin, length).Serialize()) if err != nil { return err } @@ -82,8 +82,9 @@ func (d *Downloader) Download() error { } s := sha1.Sum(buf) - fmt.Println(hex.EncodeToString(s[:])) - fmt.Println(hex.EncodeToString(d.PieceHashes[0][:])) + fmt.Printf("Downloaded %d bytes.\n", len(buf)) + fmt.Printf("Got SHA1\t%s\n", hex.EncodeToString(s[:])) + fmt.Printf("Expected\t%s\n", hex.EncodeToString(d.PieceHashes[0][:])) return nil } diff --git a/torrent/torrent.go b/torrent/torrent.go index 44a642b..4356d68 100644 --- a/torrent/torrent.go +++ b/torrent/torrent.go @@ -46,8 +46,8 @@ func (t *Torrent) Download() error { } // peers, err := t.getPeers(peerID, Port) - // fmt.Println(peers) peers := []p2p.Peer{{IP: net.IP{127, 0, 0, 1}, Port: 51413}} + fmt.Println(peers[:1]) downloader := p2p.Downloader{ Peers: peers, PeerID: peerID,