diff --git a/tunnel/default.go b/tunnel/default.go index 1bae0835..01343261 100644 --- a/tunnel/default.go +++ b/tunnel/default.go @@ -4,7 +4,6 @@ import ( "crypto/sha256" "errors" "fmt" - "io" "sync" "time" @@ -139,16 +138,17 @@ func (t *tun) monitor() { return case <-reconnect.C: for _, node := range t.options.Nodes { - t.Lock() if _, ok := t.links[node]; !ok { + t.Lock() link, err := t.setupLink(node) if err != nil { - log.Debugf("Tunnel failed to establish node link to %s: %v", node, err) + log.Debugf("Tunnel failed to setup node link to %s: %v", node, err) + t.Unlock() continue } t.links[node] = link + t.Unlock() } - t.Unlock() } } } @@ -190,11 +190,9 @@ func (t *tun) process() { } log.Debugf("Sending %+v to %s", newMsg, node) if err := link.Send(newMsg); err != nil { - log.Debugf("Error sending %+v to %s: %v", newMsg, node, err) - if err == io.EOF { - delete(t.links, node) - continue - } + log.Debugf("Tunnel error sending %+v to %s: %v", newMsg, node, err) + delete(t.links, node) + continue } } t.Unlock() @@ -211,12 +209,10 @@ func (t *tun) listen(link *link) { msg := new(transport.Message) err := link.Recv(msg) if err != nil { - log.Debugf("Tunnel link %s receive error: %v", link.Remote(), err) - if err == io.EOF { - t.Lock() - delete(t.links, link.Remote()) - t.Unlock() - } + log.Debugf("Tunnel link %s receive error: %#v", link.Remote(), err) + t.Lock() + delete(t.links, link.Remote()) + t.Unlock() return } @@ -353,13 +349,10 @@ func (t *tun) keepalive(link *link) { }, }); err != nil { log.Debugf("Error sending keepalive to link %v: %v", link.Remote(), err) - if err == io.EOF { - t.Lock() - delete(t.links, link.Remote()) - t.Unlock() - return - } - // TODO: handle this error + t.Lock() + delete(t.links, link.Remote()) + t.Unlock() + return } } } @@ -428,7 +421,7 @@ func (t *tun) connect() error { // delete the link defer func() { - log.Debugf("Deleting connection from %s", sock.Remote()) + log.Debugf("Tunnel deleting connection from %s", sock.Remote()) t.Lock() delete(t.links, sock.Remote()) t.Unlock() @@ -527,8 +520,9 @@ func (t *tun) Close() error { return nil default: // close all the sockets - for _, s := range t.sockets { + for id, s := range t.sockets { s.Close() + delete(t.sockets, id) } // close the connection close(t.closed) diff --git a/tunnel/tunnel_test.go b/tunnel/tunnel_test.go index 79db046a..08295cf8 100644 --- a/tunnel/tunnel_test.go +++ b/tunnel/tunnel_test.go @@ -9,13 +9,16 @@ import ( ) // testAccept will accept connections on the transport, create a new link and tunnel on top -func testAccept(t *testing.T, tun Tunnel, wg *sync.WaitGroup) { +func testAccept(t *testing.T, tun Tunnel, wait chan struct{}, wg *sync.WaitGroup) { // listen on some virtual address tl, err := tun.Listen("test-tunnel") if err != nil { t.Fatal(err) } + // receiver ready; notify sender + wait <- struct{}{} + // accept a connection c, err := tl.Accept() if err != nil { @@ -46,7 +49,12 @@ func testAccept(t *testing.T, tun Tunnel, wg *sync.WaitGroup) { } // testSend will create a new link to an address and then a tunnel on top -func testSend(t *testing.T, tun Tunnel) { +func testSend(t *testing.T, tun Tunnel, wait chan struct{}, wg *sync.WaitGroup) { + defer wg.Done() + + // wait for the listener to get ready + <-wait + // dial a new session c, err := tun.Dial("test-tunnel") if err != nil { @@ -95,8 +103,6 @@ func TestTunnel(t *testing.T) { } defer tunB.Close() - time.Sleep(time.Millisecond * 50) - // start tunA err = tunA.Connect() if err != nil { @@ -104,51 +110,190 @@ func TestTunnel(t *testing.T) { } defer tunA.Close() - time.Sleep(time.Millisecond * 50) + wait := make(chan struct{}) var wg sync.WaitGroup - // start accepting connections - // on tunnel A wg.Add(1) - go testAccept(t, tunA, &wg) + // start the listener + go testAccept(t, tunB, wait, &wg) - time.Sleep(time.Millisecond * 50) - - // dial and send via B - testSend(t, tunB) + wg.Add(1) + // start the client + go testSend(t, tunA, wait, &wg) // wait until done wg.Wait() } func TestLoopbackTunnel(t *testing.T) { - // create a new tunnel client + // create a new tunnel tun := NewTunnel( Address("127.0.0.1:9096"), Nodes("127.0.0.1:9096"), ) - // start tunB + // start tunnel err := tun.Connect() if err != nil { t.Fatal(err) } defer tun.Close() - time.Sleep(time.Millisecond * 50) + wait := make(chan struct{}) var wg sync.WaitGroup - // start accepting connections - // on tunnel A wg.Add(1) - go testAccept(t, tun, &wg) + // start the listener + go testAccept(t, tun, wait, &wg) - time.Sleep(time.Millisecond * 50) - - // dial and send via B - testSend(t, tun) + wg.Add(1) + // start the client + go testSend(t, tun, wait, &wg) + + // wait until done + wg.Wait() +} + +func testBrokenTunAccept(t *testing.T, tun Tunnel, wait chan struct{}, wg *sync.WaitGroup) { + defer wg.Done() + + // listen on some virtual address + tl, err := tun.Listen("test-tunnel") + if err != nil { + t.Fatal(err) + } + + // receiver ready; notify sender + wait <- struct{}{} + + // accept a connection + c, err := tl.Accept() + if err != nil { + t.Fatal(err) + } + + // accept the message and close the tunnel + // we do this to simulate loss of network connection + m := new(transport.Message) + if err := c.Recv(m); err != nil { + t.Fatal(err) + } + tun.Close() + + // re-start tunnel + err = tun.Connect() + if err != nil { + t.Fatal(err) + } + defer tun.Close() + + // listen on some virtual address + tl, err = tun.Listen("test-tunnel") + if err != nil { + t.Fatal(err) + } + + // receiver ready; notify sender + wait <- struct{}{} + + // accept a connection + c, err = tl.Accept() + if err != nil { + t.Fatal(err) + } + + // accept the message + m = new(transport.Message) + if err := c.Recv(m); err != nil { + t.Fatal(err) + } + + // notify sender we have received the message + <-wait +} + +func testBrokenTunSend(t *testing.T, tun Tunnel, wait chan struct{}, wg *sync.WaitGroup) { + defer wg.Done() + + // wait for the listener to get ready + <-wait + + // dial a new session + c, err := tun.Dial("test-tunnel") + if err != nil { + t.Fatal(err) + } + defer c.Close() + + m := transport.Message{ + Header: map[string]string{ + "test": "send", + }, + } + + // send the message + if err := c.Send(&m); err != nil { + t.Fatal(err) + } + + // wait for the listener to get ready + <-wait + + // give it time to reconnect + time.Sleep(2 * ReconnectTime) + + // send the message + if err := c.Send(&m); err != nil { + t.Fatal(err) + } + + // wait for the listener to receive the message + // c.Send merely enqueues the message to the link send queue and returns + // in order to verify it was received we wait for the listener to tell us + wait <- struct{}{} +} + +func TestReconnectTunnel(t *testing.T) { + // create a new tunnel client + tunA := NewTunnel( + Address("127.0.0.1:9096"), + Nodes("127.0.0.1:9097"), + ) + + // create a new tunnel server + tunB := NewTunnel( + Address("127.0.0.1:9097"), + ) + + // start tunnel + err := tunB.Connect() + if err != nil { + t.Fatal(err) + } + + // we manually override the tunnel.ReconnectTime value here + // this is so that we make the reconnects faster than the default 5s + ReconnectTime = 200 * time.Millisecond + + // start tunnel + err = tunA.Connect() + if err != nil { + t.Fatal(err) + } + + wait := make(chan struct{}) + + var wg sync.WaitGroup + + wg.Add(1) + // start tunnel listener + go testBrokenTunAccept(t, tunB, wait, &wg) + + wg.Add(1) + // start tunnel sender + go testBrokenTunSend(t, tunA, wait, &wg) // wait until done wg.Wait()