1
0
mirror of https://github.com/veggiedefender/torrent-client.git synced 2025-11-06 17:39:54 +02:00

Move handshake into its own package

This commit is contained in:
Jesse Li
2019-12-24 10:39:22 -05:00
parent aa8ef0ed76
commit c20965ebf5
5 changed files with 18 additions and 17 deletions

View File

@@ -1,4 +1,4 @@
package message package handshake
import ( import (
"errors" "errors"
@@ -25,8 +25,8 @@ func (h *Handshake) Serialize() []byte {
return buf return buf
} }
// ReadHandshake parses a message from a stream. Returns `nil` on keep-alive message // Read parses a message from a stream. Returns `nil` on keep-alive message
func ReadHandshake(r io.Reader) (*Handshake, error) { func Read(r io.Reader) (*Handshake, error) {
lengthBuf := make([]byte, 1) lengthBuf := make([]byte, 1)
_, err := io.ReadFull(r, lengthBuf) _, err := io.ReadFull(r, lengthBuf)
if err != nil { if err != nil {

View File

@@ -1,4 +1,4 @@
package message package handshake
import ( import (
"bytes" "bytes"
@@ -7,7 +7,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestHandshakeSerialize(t *testing.T) { func TestSerialize(t *testing.T) {
tests := map[string]struct { tests := map[string]struct {
input *Handshake input *Handshake
output []byte output []byte
@@ -28,7 +28,7 @@ func TestHandshakeSerialize(t *testing.T) {
} }
} }
func TestReadHandshake(t *testing.T) { func TestRead(t *testing.T) {
tests := map[string]struct { tests := map[string]struct {
input []byte input []byte
output *Handshake output *Handshake
@@ -62,7 +62,7 @@ func TestReadHandshake(t *testing.T) {
for _, test := range tests { for _, test := range tests {
reader := bytes.NewReader(test.input) reader := bytes.NewReader(test.input)
m, err := ReadHandshake(reader) m, err := Read(reader)
if test.fails { if test.fails {
assert.NotNil(t, err) assert.NotNil(t, err)
} else { } else {

View File

@@ -80,8 +80,8 @@ func (m *Message) Serialize() []byte {
return buf return buf
} }
// ReadMessage parses a message from a stream. Returns `nil` on keep-alive message // Read parses a message from a stream. Returns `nil` on keep-alive message
func ReadMessage(r io.Reader) (*Message, error) { func Read(r io.Reader) (*Message, error) {
lengthBuf := make([]byte, 4) lengthBuf := make([]byte, 4)
_, err := io.ReadFull(r, lengthBuf) _, err := io.ReadFull(r, lengthBuf)
if err != nil { if err != nil {

View File

@@ -129,7 +129,7 @@ func TestParsePiece(t *testing.T) {
} }
} }
func TestMessageSerialize(t *testing.T) { func TestSerialize(t *testing.T) {
tests := map[string]struct { tests := map[string]struct {
input *Message input *Message
output []byte output []byte
@@ -150,7 +150,7 @@ func TestMessageSerialize(t *testing.T) {
} }
} }
func TestReadMessage(t *testing.T) { func TestRead(t *testing.T) {
tests := map[string]struct { tests := map[string]struct {
input []byte input []byte
output *Message output *Message
@@ -180,7 +180,7 @@ func TestReadMessage(t *testing.T) {
for _, test := range tests { for _, test := range tests {
reader := bytes.NewReader(test.input) reader := bytes.NewReader(test.input)
m, err := ReadMessage(reader) m, err := Read(reader)
if test.fails { if test.fails {
assert.NotNil(t, err) assert.NotNil(t, err)
} else { } else {
@@ -190,7 +190,7 @@ func TestReadMessage(t *testing.T) {
} }
} }
func TestMessageString(t *testing.T) { func TestString(t *testing.T) {
tests := []struct { tests := []struct {
input *Message input *Message
output string output string

View File

@@ -8,6 +8,7 @@ import (
"net" "net"
"strconv" "strconv"
"github.com/veggiedefender/torrent-client/handshake"
"github.com/veggiedefender/torrent-client/message" "github.com/veggiedefender/torrent-client/message"
) )
@@ -44,7 +45,7 @@ func (d *Downloader) Download() error {
buf := make([]byte, pieceSize) buf := make([]byte, pieceSize)
i := 0 i := 0
for i < pieceSize { for i < pieceSize {
msg, err := message.ReadMessage(conn) msg, err := message.Read(conn)
if err != nil { if err != nil {
return err return err
} }
@@ -98,8 +99,8 @@ func (p *Peer) connect(peerID [20]byte, infoHash [20]byte) (net.Conn, error) {
return conn, nil return conn, nil
} }
func (d *Downloader) handshake(conn net.Conn) (*message.Handshake, error) { func (d *Downloader) handshake(conn net.Conn) (*handshake.Handshake, error) {
req := message.Handshake{ req := handshake.Handshake{
Pstr: "BitTorrent protocol", Pstr: "BitTorrent protocol",
InfoHash: d.InfoHash, InfoHash: d.InfoHash,
PeerID: d.PeerID, PeerID: d.PeerID,
@@ -109,7 +110,7 @@ func (d *Downloader) handshake(conn net.Conn) (*message.Handshake, error) {
return nil, err return nil, err
} }
res, err := message.ReadHandshake(conn) res, err := handshake.Read(conn)
if err != nil { if err != nil {
return nil, err return nil, err
} }