1
0
mirror of https://github.com/veggiedefender/torrent-client.git synced 2025-11-06 09:29:16 +02:00

Implement parsing peer messages

This commit is contained in:
Jesse Li
2019-12-22 12:51:59 -05:00
parent 941fba7a64
commit 11314c13dc
2 changed files with 136 additions and 0 deletions

67
peer/peer.go Normal file
View File

@@ -0,0 +1,67 @@
package peer
import (
"encoding/binary"
"io"
)
// Message ID
const (
MsgChoke uint8 = 0
MsgUnchoke uint8 = 1
MsgInterested uint8 = 2
MsgNotInterested uint8 = 3
MsgHave uint8 = 4
MsgBitfield uint8 = 5
MsgRequest uint8 = 6
MsgPiece uint8 = 7
MsgCancel uint8 = 8
MsgPort uint8 = 9
)
// Message m
type Message struct {
ID uint8
Payload []byte
}
// Serialize serializes a message into a buffer of the form
// <length prefix><message ID><payload>
func (m *Message) Serialize() []byte {
if m == nil {
return make([]byte, 4)
}
length := uint32(len(m.Payload) + 1) // +1 for id
buf := make([]byte, 4+length)
binary.BigEndian.PutUint32(buf[0:4], length)
buf[4] = byte(m.ID)
copy(buf[5:], m.Payload)
return buf
}
// Read parses a message from a stream. Returns `nil` on keep-alive message
func Read(r io.Reader) (*Message, error) {
lengthBuf := make([]byte, 4)
_, err := io.ReadFull(r, lengthBuf)
if err != nil {
return nil, err
}
length := binary.BigEndian.Uint32(lengthBuf)
if length == 0 {
return nil, nil
}
messageBuf := make([]byte, length)
_, err = io.ReadFull(r, messageBuf)
if err != nil {
return nil, err
}
m := Message{
ID: messageBuf[0],
Payload: messageBuf[1:],
}
return &m, nil
}

69
peer/peer_test.go Normal file
View File

@@ -0,0 +1,69 @@
package peer
import (
"bytes"
"testing"
"github.com/stretchr/testify/assert"
)
func TestSerialize(t *testing.T) {
tests := map[string]struct {
input *Message
output []byte
}{
"serialize message": {
input: &Message{ID: MsgHave, Payload: []byte{1, 2, 3, 4}},
output: []byte{0, 0, 0, 5, 4, 1, 2, 3, 4},
},
"serialize keep-alive": {
input: nil,
output: []byte{0, 0, 0, 0},
},
}
for _, test := range tests {
buf := test.input.Serialize()
assert.Equal(t, test.output, buf)
}
}
func TestRead(t *testing.T) {
tests := map[string]struct {
input []byte
output *Message
fails bool
}{
"parse normal message intro struct": {
input: []byte{0, 0, 0, 5, 4, 1, 2, 3, 4},
output: &Message{ID: MsgHave, Payload: []byte{1, 2, 3, 4}},
fails: false,
},
"parse keep-alive into nil": {
input: []byte{0, 0, 0, 0},
output: nil,
fails: false,
},
"length too short": {
input: []byte{1, 2, 3},
output: nil,
fails: true,
},
"buffer too short for length": {
input: []byte{0, 0, 0, 5, 4, 1, 2},
output: nil,
fails: true,
},
}
for _, test := range tests {
reader := bytes.NewReader(test.input)
m, err := Read(reader)
if test.fails {
assert.NotNil(t, err)
} else {
assert.Nil(t, err)
}
assert.Equal(t, test.output, m)
}
}