diff --git a/client/client_test.go b/client/client_test.go new file mode 100644 index 0000000..50c3687 --- /dev/null +++ b/client/client_test.go @@ -0,0 +1,208 @@ +package client + +import ( + "net" + "testing" + + "github.com/veggiedefender/torrent-client/bitfield" + "github.com/veggiedefender/torrent-client/handshake" + + "github.com/veggiedefender/torrent-client/message" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func createClientAndServer(t *testing.T) (clientConn, serverConn net.Conn) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.Nil(t, err) + + // net.Dial does not block, so we need this signalling channel to make sure + // we don't return before serverConn is ready + done := make(chan struct{}) + go func() { + defer ln.Close() + serverConn, err = ln.Accept() + require.Nil(t, err) + done <- struct{}{} + }() + clientConn, err = net.Dial("tcp", ln.Addr().String()) + <-done + + return clientConn, serverConn +} + +func TestRecvBitfield(t *testing.T) { + tests := map[string]struct { + msg []byte + output bitfield.Bitfield + fails bool + }{ + "successful bitfield": { + msg: []byte{0x00, 0x00, 0x00, 0x06, 5, 1, 2, 3, 4, 5}, + output: bitfield.Bitfield{1, 2, 3, 4, 5}, + fails: false, + }, + "message is not a bitfield": { + msg: []byte{0x00, 0x00, 0x00, 0x06, 99, 1, 2, 3, 4, 5}, + output: nil, + fails: true, + }, + } + + for _, test := range tests { + clientConn, serverConn := createClientAndServer(t) + serverConn.Write(test.msg) + + bf, err := recvBitfield(clientConn) + + if test.fails { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + assert.Equal(t, bf, test.output) + } + } +} + +func TestCompleteHandshake(t *testing.T) { + tests := map[string]struct { + clientInfohash [20]byte + clientPeerID [20]byte + serverHandshake []byte + output *handshake.Handshake + fails bool + }{ + "successful handshake": { + clientInfohash: [20]byte{134, 212, 200, 0, 36, 164, 105, 190, 76, 80, 188, 90, 16, 44, 247, 23, 128, 49, 0, 116}, + clientPeerID: [20]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}, + serverHandshake: []byte{19, 66, 105, 116, 84, 111, 114, 114, 101, 110, 116, 32, 112, 114, 111, 116, 111, 99, 111, 108, 0, 0, 0, 0, 0, 0, 0, 0, 134, 212, 200, 0, 36, 164, 105, 190, 76, 80, 188, 90, 16, 44, 247, 23, 128, 49, 0, 116, 45, 83, 89, 48, 48, 49, 48, 45, 192, 125, 147, 203, 136, 32, 59, 180, 253, 168, 193, 19}, + output: &handshake.Handshake{ + Pstr: "BitTorrent protocol", + InfoHash: [20]byte{134, 212, 200, 0, 36, 164, 105, 190, 76, 80, 188, 90, 16, 44, 247, 23, 128, 49, 0, 116}, + PeerID: [20]byte{45, 83, 89, 48, 48, 49, 48, 45, 192, 125, 147, 203, 136, 32, 59, 180, 253, 168, 193, 19}, + }, + fails: false, + }, + "wrong infohash": { + clientInfohash: [20]byte{134, 212, 200, 0, 36, 164, 105, 190, 76, 80, 188, 90, 16, 44, 247, 23, 128, 49, 0, 116}, + clientPeerID: [20]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}, + serverHandshake: []byte{19, 66, 105, 116, 84, 111, 114, 114, 101, 110, 116, 32, 112, 114, 111, 116, 111, 99, 111, 108, 0, 0, 0, 0, 0, 0, 0, 0, 0xde, 0xe8, 0x6a, 0x7f, 0xa6, 0xf2, 0x86, 0xa9, 0xd7, 0x4c, 0x36, 0x20, 0x14, 0x61, 0x6a, 0x0f, 0xf5, 0xe4, 0x84, 0x3d, 45, 83, 89, 48, 48, 49, 48, 45, 192, 125, 147, 203, 136, 32, 59, 180, 253, 168, 193, 19}, + output: nil, + fails: true, + }, + } + + for _, test := range tests { + clientConn, serverConn := createClientAndServer(t) + serverConn.Write(test.serverHandshake) + + h, err := completeHandshake(clientConn, test.clientInfohash, test.clientPeerID) + + if test.fails { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + assert.Equal(t, h, test.output) + } + } +} + +func TestRead(t *testing.T) { + clientConn, serverConn := createClientAndServer(t) + client := Client{Conn: clientConn} + + msgBytes := []byte{ + 0x00, 0x00, 0x00, 0x05, + 4, + 0x00, 0x00, 0x05, 0x3c, + } + expected := &message.Message{ + ID: message.MsgHave, + Payload: []byte{0x00, 0x00, 0x05, 0x3c}, + } + _, err := serverConn.Write(msgBytes) + require.Nil(t, err) + + msg, err := client.Read() + assert.Equal(t, expected, msg) +} + +func TestSendRequest(t *testing.T) { + clientConn, serverConn := createClientAndServer(t) + client := Client{Conn: clientConn} + err := client.SendRequest(1, 2, 3) + assert.Nil(t, err) + expected := []byte{ + 0x00, 0x00, 0x00, 0x0d, + 6, + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x02, + 0x00, 0x00, 0x00, 0x03, + } + buf := make([]byte, len(expected)) + _, err = serverConn.Read(buf) + assert.Nil(t, err) + assert.Equal(t, expected, buf) +} + +func TestSendInterested(t *testing.T) { + clientConn, serverConn := createClientAndServer(t) + client := Client{Conn: clientConn} + err := client.SendInterested() + assert.Nil(t, err) + expected := []byte{ + 0x00, 0x00, 0x00, 0x01, + 2, + } + buf := make([]byte, len(expected)) + _, err = serverConn.Read(buf) + assert.Nil(t, err) + assert.Equal(t, expected, buf) +} + +func TestSendNotInterested(t *testing.T) { + clientConn, serverConn := createClientAndServer(t) + client := Client{Conn: clientConn} + err := client.SendNotInterested() + assert.Nil(t, err) + expected := []byte{ + 0x00, 0x00, 0x00, 0x01, + 3, + } + buf := make([]byte, len(expected)) + _, err = serverConn.Read(buf) + assert.Nil(t, err) + assert.Equal(t, expected, buf) +} + +func TestSendUnchoke(t *testing.T) { + clientConn, serverConn := createClientAndServer(t) + client := Client{Conn: clientConn} + err := client.SendUnchoke() + assert.Nil(t, err) + expected := []byte{ + 0x00, 0x00, 0x00, 0x01, + 1, + } + buf := make([]byte, len(expected)) + _, err = serverConn.Read(buf) + assert.Nil(t, err) + assert.Equal(t, expected, buf) +} + +func TestSendHave(t *testing.T) { + clientConn, serverConn := createClientAndServer(t) + client := Client{Conn: clientConn} + err := client.SendHave(1340) + assert.Nil(t, err) + expected := []byte{ + 0x00, 0x00, 0x00, 0x05, + 4, + 0x00, 0x00, 0x05, 0x3c, + } + buf := make([]byte, len(expected)) + _, err = serverConn.Read(buf) + assert.Nil(t, err) + assert.Equal(t, expected, buf) +}