From 7b1f5584abfe3cd36534fbbf1a133a39be80575f Mon Sep 17 00:00:00 2001 From: Asim Aslam Date: Tue, 15 Oct 2019 15:40:04 +0100 Subject: [PATCH] Tunnel mode --- tunnel/broker/broker.go | 2 +- tunnel/default.go | 25 +++++++++++++++--------- tunnel/listener.go | 4 ++-- tunnel/options.go | 20 ++++++++++++++++--- tunnel/session.go | 43 +++++++++++++++++++++-------------------- tunnel/tunnel.go | 10 +++++++++- 6 files changed, 67 insertions(+), 37 deletions(-) diff --git a/tunnel/broker/broker.go b/tunnel/broker/broker.go index 6778dfaa..75779698 100644 --- a/tunnel/broker/broker.go +++ b/tunnel/broker/broker.go @@ -71,7 +71,7 @@ func (t *tunBroker) Publish(topic string, m *broker.Message, opts ...broker.Publ } func (t *tunBroker) Subscribe(topic string, h broker.Handler, opts ...broker.SubscribeOption) (broker.Subscriber, error) { - l, err := t.tunnel.Listen(topic) + l, err := t.tunnel.Listen(topic, tunnel.ListenMulticast()) if err != nil { return nil, err } diff --git a/tunnel/default.go b/tunnel/default.go index 16ca339e..fe63fe26 100644 --- a/tunnel/default.go +++ b/tunnel/default.go @@ -326,7 +326,7 @@ func (t *tun) process() { } // check the multicast mappings - if msg.multicast { + if msg.mode > Unicast { link.RLock() _, ok := link.channels[msg.channel] link.RUnlock() @@ -366,7 +366,7 @@ func (t *tun) process() { sent = true // keep sending broadcast messages - if msg.broadcast || msg.multicast { + if msg.mode > Unicast { continue } @@ -523,7 +523,7 @@ func (t *tun) listen(link *link) { case "accept": s, exists := t.getSession(channel, sessionId) // we don't need this - if exists && s.multicast { + if exists && s.mode > Unicast { s.accepted = true continue } @@ -963,7 +963,7 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { } // set the multicast option - c.multicast = options.Multicast + c.mode = options.Mode // set the dial timeout c.timeout = options.Timeout @@ -1009,7 +1009,7 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { // discovered so set the link if not multicast // TODO: pick the link efficiently based // on link status and saturation. - if c.discovered && !c.multicast { + if c.discovered && c.mode == Unicast { // set the link i := rand.Intn(len(links)) c.link = links[i] @@ -1019,7 +1019,7 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { if !c.discovered { // create a new discovery message for this channel msg := c.newMessage("discover") - msg.broadcast = true + msg.mode = Broadcast msg.outbound = true msg.link = "" @@ -1041,7 +1041,7 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { dialTimeout := after() // set a shorter delay for multicast - if c.multicast { + if c.mode > Unicast { // shorten this dialTimeout = time.Millisecond * 500 } @@ -1057,7 +1057,7 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { } // if its multicast just go ahead because this is best effort - if c.multicast { + if c.mode > Unicast { c.discovered = true c.accepted = true return c, nil @@ -1086,9 +1086,14 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { } // Accept a connection on the address -func (t *tun) Listen(channel string) (Listener, error) { +func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) { log.Debugf("Tunnel listening on %s", channel) + var options ListenOptions + for _, o := range opts { + o(&options) + } + // create a new session by hashing the address c, ok := t.newSession(channel, "listener") if !ok { @@ -1103,6 +1108,8 @@ func (t *tun) Listen(channel string) (Listener, error) { c.remote = "remote" // set local c.local = channel + // set mode + c.mode = options.Mode tl := &tunListener{ channel: channel, diff --git a/tunnel/listener.go b/tunnel/listener.go index f154b2a6..c893297d 100644 --- a/tunnel/listener.go +++ b/tunnel/listener.go @@ -82,8 +82,8 @@ func (t *tunListener) process() { loopback: m.loopback, // the link the message was received on link: m.link, - // set multicast - multicast: m.multicast, + // set the connection mode + mode: m.mode, // close chan closed: make(chan bool), // recv called by the acceptor diff --git a/tunnel/options.go b/tunnel/options.go index 903f7fb7..54e740e2 100644 --- a/tunnel/options.go +++ b/tunnel/options.go @@ -36,12 +36,19 @@ type DialOption func(*DialOptions) type DialOptions struct { // Link specifies the link to use Link string - // specify a multicast connection - Multicast bool + // specify mode of the session + Mode Mode // the dial timeout Timeout time.Duration } +type ListenOption func(*ListenOptions) + +type ListenOptions struct { + // specify mode of the session + Mode Mode +} + // The tunnel id func Id(id string) Option { return func(o *Options) { @@ -87,12 +94,19 @@ func DefaultOptions() Options { } } +// Listen options +func ListenMulticast() ListenOption { + return func(o *ListenOptions) { + o.Mode = Multicast + } +} + // Dial options // Dial multicast sets the multicast option to send only to those mapped func DialMulticast() DialOption { return func(o *DialOptions) { - o.Multicast = true + o.Mode = Multicast } } diff --git a/tunnel/session.go b/tunnel/session.go index 6757f150..0cd3bcce 100644 --- a/tunnel/session.go +++ b/tunnel/session.go @@ -37,10 +37,8 @@ type session struct { outbound bool // lookback marks the session as a loopback on the inbound loopback bool - // if the session is multicast - multicast bool - // if the session is broadcast - broadcast bool + // mode of the connection + mode Mode // the timeout timeout time.Duration // the link on which this message was received @@ -63,10 +61,8 @@ type message struct { outbound bool // loopback marks the message intended for loopback loopback bool - // whether to send as multicast - multicast bool - // broadcast sets the broadcast type - broadcast bool + // mode of the connection + mode Mode // the link to send the message on link string // transport data @@ -98,15 +94,15 @@ func (s *session) Channel() string { // newMessage creates a new message based on the session func (s *session) newMessage(typ string) *message { return &message{ - typ: typ, - tunnel: s.tunnel, - channel: s.channel, - session: s.session, - outbound: s.outbound, - loopback: s.loopback, - multicast: s.multicast, - link: s.link, - errChan: s.errChan, + typ: typ, + tunnel: s.tunnel, + channel: s.channel, + session: s.session, + outbound: s.outbound, + loopback: s.loopback, + mode: s.mode, + link: s.link, + errChan: s.errChan, } } @@ -128,8 +124,8 @@ func (s *session) Open() error { return io.EOF } - // we don't wait on multicast - if s.multicast { + // don't wait on multicast/broadcast + if s.mode > Unicast { s.accepted = true return nil } @@ -166,6 +162,11 @@ func (s *session) Accept() error { // no op here } + // don't wait on multicast/broadcast + if s.mode > Unicast { + return nil + } + // wait for send response select { case err := <-s.errChan: @@ -185,7 +186,7 @@ func (s *session) Announce() error { // we don't need an error back msg.errChan = nil // announce to all - msg.broadcast = true + msg.mode = Broadcast // we don't need the link msg.link = "" @@ -222,7 +223,7 @@ func (s *session) Send(m *transport.Message) error { msg.data = data // if multicast don't set the link - if s.multicast { + if s.mode > Unicast { msg.link = "" } diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 312a6681..8b9b347a 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -8,6 +8,12 @@ import ( "github.com/micro/go-micro/transport" ) +const ( + Unicast Mode = iota + Multicast + Broadcast +) + var ( // DefaultDialTimeout is the dial timeout if none is specified DefaultDialTimeout = time.Second * 5 @@ -19,6 +25,8 @@ var ( ErrLinkNotFound = errors.New("link not found") ) +type Mode uint8 + // Tunnel creates a gre tunnel on top of the go-micro/transport. // It establishes multiple streams using the Micro-Tunnel-Channel header // and Micro-Tunnel-Session header. The tunnel id is a hash of @@ -36,7 +44,7 @@ type Tunnel interface { // Connect to a channel Dial(channel string, opts ...DialOption) (Session, error) // Accept connections on a channel - Listen(channel string) (Listener, error) + Listen(channel string, opts ...ListenOption) (Listener, error) // Name of the tunnel implementation String() string }