From 6ab86c9e576139c337ed03d5b0b17e85a5e93efa Mon Sep 17 00:00:00 2001 From: Asim Aslam Date: Wed, 28 Aug 2019 23:12:22 +0100 Subject: [PATCH] Don't process unless connected, and only fire loopback messages back up the loopback --- tunnel/default.go | 106 +++++++++++++++++++++++++++++++-------------- tunnel/listener.go | 2 + tunnel/socket.go | 7 ++- 3 files changed, 82 insertions(+), 33 deletions(-) diff --git a/tunnel/default.go b/tunnel/default.go index 69a5d878..bd6af7dd 100644 --- a/tunnel/default.go +++ b/tunnel/default.go @@ -49,8 +49,18 @@ type tun struct { type link struct { transport.Socket - id string - loopback bool + // unique id of this link e.g uuid + // which we define for ourselves + id string + // whether its a loopback connection + loopback bool + // whether its actually connected + // dialled side sets it to connected + // after sending the message. the + // listener waits for the connect + connected bool + // the last time we received a keepalive + // on this link from the remote side lastKeepAlive time.Time } @@ -190,9 +200,25 @@ func (t *tun) process() { log.Debugf("No links to send to") } for node, link := range t.links { + // if the link is not connected skip it + if !link.connected { + log.Debugf("Link for node %s not connected", node) + continue + } + + // if the link was a loopback accepted connection + // and the message is being sent outbound via + // a dialled connection don't use this link if link.loopback && msg.outbound { continue } + + // if the message was being returned by the loopback listener + // send it back up the loopback link only + if msg.loopback && !link.loopback { + continue + } + log.Debugf("Sending %+v to %s", newMsg, node) if err := link.Send(newMsg); err != nil { log.Debugf("Tunnel error sending %+v to %s: %v", newMsg, node, err) @@ -209,15 +235,22 @@ func (t *tun) process() { // process incoming messages func (t *tun) listen(link *link) { + // remove the link on exit + defer func() { + log.Debugf("Tunnel deleting connection from %s", link.Remote()) + t.Lock() + delete(t.links, link.Remote()) + t.Unlock() + }() + + // let us know if its a loopback + var loopback bool + for { // process anything via the net interface msg := new(transport.Message) - err := link.Recv(msg) - if err != nil { + if err := link.Recv(msg); err != nil { log.Debugf("Tunnel link %s receive error: %#v", link.Remote(), err) - t.Lock() - delete(t.links, link.Remote()) - t.Unlock() return } @@ -232,11 +265,18 @@ func (t *tun) listen(link *link) { // are we connecting to ourselves? if token == t.token { - t.Lock() link.loopback = true - t.Unlock() + loopback = true } + // set as connected + link.connected = true + + // save the link once connected + t.Lock() + t.links[link.Remote()] = link + t.Unlock() + // nothing more to do continue case "close": @@ -258,6 +298,11 @@ func (t *tun) listen(link *link) { continue } + // if its not connected throw away the link + if !link.connected { + return + } + // strip message header delete(msg.Header, "Micro-Tunnel") @@ -283,8 +328,10 @@ func (t *tun) listen(link *link) { var s *socket var exists bool + // If its a loopback connection then we've enabled link direction + // listening side is used for listening, the dialling side for dialling switch { - case link.loopback: + case loopback: s, exists = t.getSocket(id, "listener") default: // get the socket based on the tunnel id and session @@ -298,6 +345,7 @@ func (t *tun) listen(link *link) { s, exists = t.getSocket(id, "listener") } } + // bail if no socket has been found if !exists { log.Debugf("Tunnel skipping no socket exists") @@ -337,9 +385,10 @@ func (t *tun) listen(link *link) { // construct the internal message imsg := &message{ - id: id, - session: session, - data: tmsg, + id: id, + session: session, + data: tmsg, + loopback: loopback, } // append to recv backlog @@ -399,13 +448,14 @@ func (t *tun) setupLink(node string) (*link, error) { return nil, err } - // save the link - id := uuid.New().String() + // create a new link link := &link{ Socket: c, - id: id, + id: uuid.New().String(), + // we made the outbound connection + // and sent the connect message + connected: true, } - t.links[node] = link // process incoming messages go t.listen(link) @@ -430,25 +480,16 @@ func (t *tun) connect() error { // accept inbound connections err := l.Accept(func(sock transport.Socket) { log.Debugf("Tunnel accepted connection from %s", sock.Remote()) - // save the link - id := uuid.New().String() - t.Lock() + + // create a new link link := &link{ Socket: sock, - id: id, + id: uuid.New().String(), } - t.links[sock.Remote()] = link - t.Unlock() - // delete the link - defer func() { - log.Debugf("Tunnel deleting connection from %s", sock.Remote()) - t.Lock() - delete(t.links, sock.Remote()) - t.Unlock() - }() - - // listen for inbound messages + // listen for inbound messages. + // only save the link once connected. + // we do this inside liste t.listen(link) }) @@ -473,6 +514,7 @@ func (t *tun) connect() error { log.Debugf("Tunnel failed to establish node link to %s: %v", node, err) continue } + // save the link t.links[node] = link } diff --git a/tunnel/listener.go b/tunnel/listener.go index 3002e7b6..b953601b 100644 --- a/tunnel/listener.go +++ b/tunnel/listener.go @@ -41,6 +41,8 @@ func (t *tunListener) process() { id: m.id, // the session id session: m.session, + // is loopback conn + loopback: m.loopback, // close chan closed: make(chan bool), // recv called by the acceptor diff --git a/tunnel/socket.go b/tunnel/socket.go index 2590a48e..921d4cf0 100644 --- a/tunnel/socket.go +++ b/tunnel/socket.go @@ -25,8 +25,10 @@ type socket struct { recv chan *message // wait until we have a connection wait chan bool - // outbound marks the socket as outbound + // outbound marks the socket as outbound dialled connection outbound bool + // lookback marks the socket as a loopback on the inbound + loopback bool } // message is sent over the send channel @@ -37,6 +39,8 @@ type message struct { session string // outbound marks the message as outbound outbound bool + // loopback marks the message intended for loopback + loopback bool // transport data data *transport.Message } @@ -80,6 +84,7 @@ func (s *socket) Send(m *transport.Message) error { id: s.id, session: s.session, outbound: s.outbound, + loopback: s.loopback, data: data, } log.Debugf("Appending %+v to send backlog", msg)