1
0
mirror of https://github.com/veggiedefender/torrent-client.git synced 2025-11-06 09:29:16 +02:00

Implement formatting requests and parsing pieces

This commit is contained in:
Jesse
2019-12-23 12:56:11 -05:00
parent 7999e23fe1
commit cdfb02a591
5 changed files with 181 additions and 18 deletions

View File

@@ -18,5 +18,8 @@ func main() {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
t.Download() err = t.Download()
if err != nil {
log.Fatal(err)
}
} }

View File

@@ -2,6 +2,7 @@ package message
import ( import (
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"io" "io"
) )
@@ -28,6 +29,42 @@ type Message struct {
Payload []byte Payload []byte
} }
// FormatRequest formats the ID and payload for a request message
func FormatRequest(index, begin, length int) *Message {
payload := make([]byte, 12)
binary.BigEndian.PutUint32(payload[0:4], uint32(index))
binary.BigEndian.PutUint32(payload[4:8], uint32(begin))
binary.BigEndian.PutUint32(payload[8:12], uint32(length))
return &Message{
ID: MsgRequest,
Payload: payload,
}
}
// ParsePiece parses a piece message and copies its payload into a buffer
func ParsePiece(index int, buf []byte, msg *Message) (int, error) {
if msg.ID != MsgPiece {
return 0, fmt.Errorf("Expected ID %d, got ID %d", MsgPiece, msg.ID)
}
if len(msg.Payload) < 8 {
return 0, errors.New("Payload too short")
}
parsedIndex := int(binary.BigEndian.Uint32(msg.Payload[0:4]))
if parsedIndex != index {
return 0, fmt.Errorf("Expected index %d, got %d", index, parsedIndex)
}
begin := int(binary.BigEndian.Uint32(msg.Payload[4:8]))
if begin >= len(buf) {
return 0, fmt.Errorf("Begin offset too high. %d >= %d", begin, len(buf))
}
data := msg.Payload[8:]
if begin+len(data) > len(buf) {
return 0, fmt.Errorf("Data too long [%d] for offset %d with length %d", len(data), begin, len(buf))
}
copy(buf[begin:], data)
return len(data), nil
}
// Serialize serializes a message into a buffer of the form // Serialize serializes a message into a buffer of the form
// <length prefix><message ID><payload> // <length prefix><message ID><payload>
// Interprets `nil` as a keep-alive message // Interprets `nil` as a keep-alive message

View File

@@ -7,6 +7,128 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestFormatRequest(t *testing.T) {
msg := FormatRequest(4, 567, 4321)
expected := &Message{
ID: MsgRequest,
Payload: []byte{
0x00, 0x00, 0x00, 0x04, // Index
0x00, 0x00, 0x02, 0x37, // Begin
0x00, 0x00, 0x10, 0xe1, // Length
},
}
assert.Equal(t, expected, msg)
}
func TestParsePiece(t *testing.T) {
tests := map[string]struct {
inputIndex int
inputBuf []byte
inputMsg *Message
outputN int
outputBuf []byte
fails bool
}{
"parse valid piece": {
inputIndex: 4,
inputBuf: make([]byte, 10),
inputMsg: &Message{
ID: MsgPiece,
Payload: []byte{
0x00, 0x00, 0x00, 0x04, // Index
0x00, 0x00, 0x00, 0x02, // Begin
0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, // Block
},
},
outputBuf: []byte{0x00, 0x00, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00, 0x00},
outputN: 6,
fails: false,
},
"wrong message type": {
inputIndex: 4,
inputBuf: make([]byte, 10),
inputMsg: &Message{
ID: MsgChoke,
Payload: []byte{},
},
outputBuf: make([]byte, 10),
outputN: 0,
fails: true,
},
"payload too short": {
inputIndex: 4,
inputBuf: make([]byte, 10),
inputMsg: &Message{
ID: MsgPiece,
Payload: []byte{
0x00, 0x00, 0x00, 0x04, // Index
0x00, 0x00, 0x00, // Malformed offset
},
},
outputBuf: make([]byte, 10),
outputN: 0,
fails: true,
},
"wrong index": {
inputIndex: 4,
inputBuf: make([]byte, 10),
inputMsg: &Message{
ID: MsgPiece,
Payload: []byte{
0x00, 0x00, 0x00, 0x06, // Index is 6, not 4
0x00, 0x00, 0x00, 0x02, // Begin
0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, // Block
},
},
outputBuf: make([]byte, 10),
outputN: 0,
fails: true,
},
"offset too high": {
inputIndex: 4,
inputBuf: make([]byte, 10),
inputMsg: &Message{
ID: MsgPiece,
Payload: []byte{
0x00, 0x00, 0x00, 0x04, // Index is 6, not 4
0x00, 0x00, 0x00, 0x0c, // Begin is 12 > 10
0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, // Block
},
},
outputBuf: make([]byte, 10),
outputN: 0,
fails: true,
},
"offset ok but payload too long": {
inputIndex: 4,
inputBuf: make([]byte, 10),
inputMsg: &Message{
ID: MsgPiece,
Payload: []byte{
0x00, 0x00, 0x00, 0x04, // Index is 6, not 4
0x00, 0x00, 0x00, 0x02, // Begin is ok
// Block is 10 long but begin=2; too long for input buffer
0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x0a, 0x0b, 0x0c, 0x0d,
},
},
outputBuf: make([]byte, 10),
outputN: 0,
fails: true,
},
}
for _, test := range tests {
n, err := ParsePiece(test.inputIndex, test.inputBuf, test.inputMsg)
if test.fails {
assert.NotNil(t, err)
} else {
assert.Nil(t, err)
}
assert.Equal(t, test.outputBuf, test.inputBuf)
assert.Equal(t, test.outputN, n)
}
}
func TestMessageSerialize(t *testing.T) { func TestMessageSerialize(t *testing.T) {
tests := map[string]struct { tests := map[string]struct {
input *Message input *Message

View File

@@ -2,7 +2,6 @@ package p2p
import ( import (
"crypto/sha1" "crypto/sha1"
"encoding/binary"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"math" "math"
@@ -34,10 +33,11 @@ func (d *Downloader) Download() error {
if err != nil { if err != nil {
return err return err
} }
_, err = d.handshake(conn) h, err := d.handshake(conn)
if err != nil { if err != nil {
return err return err
} }
fmt.Println(h)
choked := false choked := false
pieceSize := d.Length / len(d.PieceHashes) pieceSize := d.Length / len(d.PieceHashes)
@@ -49,15 +49,23 @@ func (d *Downloader) Download() error {
return err return err
} }
if msg.ID != message.MsgPiece {
fmt.Println(msg.String())
} else {
fmt.Println("Received", len(msg.Payload), "bytes")
}
switch msg.ID { switch msg.ID {
case message.MsgChoke: case message.MsgChoke:
choked = true choked = true
case message.MsgUnchoke: case message.MsgUnchoke:
choked = false choked = false
case message.MsgPiece: case message.MsgPiece:
begin := binary.BigEndian.Uint32(msg.Payload[4:8]) n, err := message.ParsePiece(0, buf, msg)
copy(buf[begin:], msg.Payload[8:]) if err != nil {
i += (len(msg.Payload) - 8) return err
}
i += n
} }
if !choked { if !choked {
@@ -66,15 +74,7 @@ func (d *Downloader) Download() error {
remain := pieceSize - i remain := pieceSize - i
length := int(math.Min(float64(16384), float64(pieceSize))) length := int(math.Min(float64(16384), float64(pieceSize)))
length = int(math.Min(float64(remain), float64(length))) length = int(math.Min(float64(remain), float64(length)))
payload := make([]byte, 12) _, err := conn.Write(message.FormatRequest(index, begin, length).Serialize())
binary.BigEndian.PutUint32(payload[0:4], uint32(index))
binary.BigEndian.PutUint32(payload[4:8], uint32(begin))
binary.BigEndian.PutUint32(payload[8:12], uint32(length))
request := message.Message{
ID: message.MsgRequest,
Payload: payload,
}
_, err := conn.Write(request.Serialize())
if err != nil { if err != nil {
return err return err
} }
@@ -82,8 +82,9 @@ func (d *Downloader) Download() error {
} }
s := sha1.Sum(buf) s := sha1.Sum(buf)
fmt.Println(hex.EncodeToString(s[:])) fmt.Printf("Downloaded %d bytes.\n", len(buf))
fmt.Println(hex.EncodeToString(d.PieceHashes[0][:])) fmt.Printf("Got SHA1\t%s\n", hex.EncodeToString(s[:]))
fmt.Printf("Expected\t%s\n", hex.EncodeToString(d.PieceHashes[0][:]))
return nil return nil
} }

View File

@@ -46,8 +46,8 @@ func (t *Torrent) Download() error {
} }
// peers, err := t.getPeers(peerID, Port) // peers, err := t.getPeers(peerID, Port)
// fmt.Println(peers)
peers := []p2p.Peer{{IP: net.IP{127, 0, 0, 1}, Port: 51413}} peers := []p2p.Peer{{IP: net.IP{127, 0, 0, 1}, Port: 51413}}
fmt.Println(peers[:1])
downloader := p2p.Downloader{ downloader := p2p.Downloader{
Peers: peers, Peers: peers,
PeerID: peerID, PeerID: peerID,