From b4236f44304a76b82ae06a3bca7d932acc19c52f Mon Sep 17 00:00:00 2001 From: Asim Aslam Date: Tue, 2 Jul 2019 00:27:53 +0100 Subject: [PATCH] Add network transport --- transport/network/listener.go | 44 ++++ transport/network/network.go | 339 +++++++++++++++++++++++++++++++ transport/network/socket.go | 80 ++++++++ transport/network/socket_test.go | 61 ++++++ 4 files changed, 524 insertions(+) create mode 100644 transport/network/listener.go create mode 100644 transport/network/network.go create mode 100644 transport/network/socket.go create mode 100644 transport/network/socket_test.go diff --git a/transport/network/listener.go b/transport/network/listener.go new file mode 100644 index 00000000..6e8efc60 --- /dev/null +++ b/transport/network/listener.go @@ -0,0 +1,44 @@ +package network + +import ( + "github.com/micro/go-micro/transport" +) + +type listener struct { + // stream id + id string + // address of the listener + addr string + // close channel + closed chan bool + // accept socket + accept chan *socket +} + +func (n *listener) Addr() string { + return n.addr +} + +func (n *listener) Close() error { + select { + case <-n.closed: + default: + close(n.closed) + } + return nil +} + +func (n *listener) Accept(fn func(s transport.Socket)) error { + for { + select { + case <-n.closed: + return nil + case s, ok := <-n.accept: + if !ok { + return nil + } + go fn(s) + } + } + return nil +} diff --git a/transport/network/network.go b/transport/network/network.go new file mode 100644 index 00000000..791edbe7 --- /dev/null +++ b/transport/network/network.go @@ -0,0 +1,339 @@ +// Package network provides a network transport +package network + +import ( + "context" + "crypto/sha256" + "errors" + "fmt" + "sync" + + "github.com/micro/go-micro/network" + "github.com/micro/go-micro/transport" +) + +type networkKey struct{} + +type Transport struct { + options transport.Options + + // the network interface + network network.Network + + // protect all the things + sync.RWMutex + + // connect + connected bool + // connected node + node network.Node + // the send channel + send chan *message + // close channel + closed chan bool + + // sockets + sockets map[string]*socket + // listeners + listeners map[string]*listener +} + +func (n *Transport) newListener(addr string) *listener { + // hash the id + h := sha256.New() + h.Write([]byte(addr)) + id := fmt.Sprintf("%x", h.Sum(nil)) + + // create the listener + l := &listener{ + id: id, + addr: addr, + closed: make(chan bool), + accept: make(chan *socket, 128), + } + + // save it + n.Lock() + n.listeners[id] = l + n.Unlock() + + return l +} + +func (n *Transport) getListener(id string) (*listener, bool) { + // get the listener + n.RLock() + s, ok := n.listeners[id] + n.RUnlock() + return s, ok +} + +func (n *Transport) getSocket(id string) (*socket, bool) { + // get the socket + n.RLock() + s, ok := n.sockets[id] + n.RUnlock() + return s, ok +} + +func (n *Transport) newSocket(id string) *socket { + // hash the id + h := sha256.New() + h.Write([]byte(id)) + id = fmt.Sprintf("%x", h.Sum(nil)) + + // new socket + s := &socket{ + id: id, + closed: make(chan bool), + recv: make(chan *message, 128), + send: n.send, + } + + // save socket + n.Lock() + n.sockets[id] = s + n.Unlock() + + // return socket + return s +} + +// process outgoing messages +func (n *Transport) process() { + // manage the send buffer + // all pseudo sockets throw everything down this + for { + select { + case msg := <-n.send: + netmsg := &network.Message{ + Header: msg.data.Header, + Body: msg.data.Body, + } + + // set the stream id on the outgoing message + netmsg.Header["Micro-Stream"] = msg.id + + // send the message via the interface + if err := n.node.Send(netmsg); err != nil { + // no op + // TODO: do something + } + case <-n.closed: + return + } + } +} + +// process incoming messages +func (n *Transport) listen() { + for { + // process anything via the net interface + msg, err := n.node.Accept() + if err != nil { + return + } + + // a stream id + id := msg.Header["Micro-Stream"] + + // get the socket + s, exists := n.getSocket(id) + if !exists { + // get the listener + l, ok := n.getListener(id) + // there's no socket and there's no listener + if !ok { + continue + } + + // listener is closed + select { + case <-l.closed: + // delete it + n.Lock() + delete(n.listeners, l.id) + n.Unlock() + continue + default: + } + + // no socket, create one + s = n.newSocket(id) + // set remote address + s.remote = msg.Header["Remote"] + + // drop that to the listener + // TODO: non blocking + l.accept <- s + } + + // is the socket closed? + select { + case <-s.closed: + // closed + delete(n.sockets, id) + continue + default: + // process + } + + tmsg := &transport.Message{ + Header: msg.Header, + Body: msg.Body, + } + + // TODO: don't block on queuing + // append to recv backlog + s.recv <- &message{id: id, data: tmsg} + } +} + +func (n *Transport) Init(opts ...transport.Option) error { + for _, o := range opts { + o(&n.options) + } + return nil +} + +func (n *Transport) Options() transport.Options { + return n.options +} + +// Close the tunnel +func (n *Transport) Close() error { + n.Lock() + defer n.Unlock() + + if !n.connected { + return nil + } + + select { + case <-n.closed: + return nil + default: + // close all the sockets + for _, s := range n.sockets { + s.Close() + } + for _, l := range n.listeners { + l.Close() + } + // close the connection + close(n.closed) + // close node connection + n.node.Close() + // reset connected + n.connected = false + } + + return nil +} + +// Connect the tunnel +func (n *Transport) Connect() error { + n.Lock() + defer n.Unlock() + + // already connected + if n.connected { + return nil + } + + // get a new node + node, err := n.network.Connect() + if err != nil { + return err + } + + // set as connected + n.connected = true + // create new close channel + n.closed = make(chan bool) + // save node + n.node = node + + // process messages to be sent + go n.process() + // process incoming messages + go n.listen() + + return nil +} + +// Dial an address +func (n *Transport) Dial(addr string, opts ...transport.DialOption) (transport.Client, error) { + if err := n.Connect(); err != nil { + return nil, err + } + + // create new socket + s := n.newSocket(addr) + // set remote + s.remote = addr + // set local + n.RLock() + s.local = n.node.Address() + n.RUnlock() + + return s, nil +} + +func (n *Transport) Listen(addr string, opts ...transport.ListenOption) (transport.Listener, error) { + // check existing listeners + n.RLock() + for _, l := range n.listeners { + if l.addr == addr { + n.RUnlock() + return nil, errors.New("already listening on " + addr) + } + } + n.RUnlock() + + // try to connect to the network + if err := n.Connect(); err != nil { + return nil, err + } + + return n.newListener(addr), nil +} + +func (n *Transport) String() string { + return "network" +} + +// NewTransport creates a new network transport +func NewTransport(opts ...transport.Option) transport.Transport { + options := transport.Options{ + Context: context.Background(), + } + + for _, o := range opts { + o(&options) + } + + // get the network interface + n, ok := options.Context.Value(networkKey{}).(network.Network) + if !ok { + n = network.DefaultNetwork + } + + return &Transport{ + options: options, + network: n, + send: make(chan *message, 128), + closed: make(chan bool), + sockets: make(map[string]*socket), + } +} + +// WithNetwork sets the network interface +func WithNetwork(n network.Network) transport.Option { + return func(o *transport.Options) { + if o.Context == nil { + o.Context = context.Background() + } + o.Context = context.WithValue(o.Context, networkKey{}, n) + } +} diff --git a/transport/network/socket.go b/transport/network/socket.go new file mode 100644 index 00000000..92de114f --- /dev/null +++ b/transport/network/socket.go @@ -0,0 +1,80 @@ +package network + +import ( + "errors" + + "github.com/micro/go-micro/transport" +) + +// socket is our pseudo socket for transport.Socket +type socket struct { + // socket id based on Micro-Stream + id string + // closed + closed chan bool + // remote addr + remote string + // local addr + local string + // send chan + send chan *message + // recv chan + recv chan *message +} + +// message is sent over the send channel +type message struct { + // socket id + id string + // transport data + data *transport.Message +} + +func (s *socket) Remote() string { + return s.remote +} + +func (s *socket) Local() string { + return s.local +} + +func (s *socket) Id() string { + return s.id +} + +func (s *socket) Send(m *transport.Message) error { + select { + case <-s.closed: + return errors.New("socket is closed") + default: + // no op + } + // append to backlog + s.send <- &message{id: s.id, data: m} + return nil +} + +func (s *socket) Recv(m *transport.Message) error { + select { + case <-s.closed: + return errors.New("socket is closed") + default: + // no op + } + // recv from backlog + msg := <-s.recv + // set message + *m = *msg.data + // return nil + return nil +} + +func (s *socket) Close() error { + select { + case <-s.closed: + // no op + default: + close(s.closed) + } + return nil +} diff --git a/transport/network/socket_test.go b/transport/network/socket_test.go new file mode 100644 index 00000000..8521a3da --- /dev/null +++ b/transport/network/socket_test.go @@ -0,0 +1,61 @@ +package network + +import ( + "testing" + + "github.com/micro/go-micro/transport" +) + +func TestTunnelSocket(t *testing.T) { + s := &socket{ + id: "1", + closed: make(chan bool), + remote: "remote", + local: "local", + send: make(chan *message, 1), + recv: make(chan *message, 1), + } + + // check addresses local and remote + if s.Local() != s.local { + t.Fatalf("Expected s.Local %s got %s", s.local, s.Local()) + } + if s.Remote() != s.remote { + t.Fatalf("Expected s.Remote %s got %s", s.remote, s.Remote()) + } + + // send a message + s.Send(&transport.Message{Header: map[string]string{}}) + + // get sent message + msg := <-s.send + + if msg.id != s.id { + t.Fatalf("Expected sent message id %s got %s", s.id, msg.id) + } + + // recv a message + msg.data.Header["Foo"] = "bar" + s.recv <- msg + + m := new(transport.Message) + s.Recv(m) + + // check header + if m.Header["Foo"] != "bar" { + t.Fatalf("Did not receive correct message %+v", m) + } + + // close the connection + s.Close() + + // check connection + err := s.Send(m) + if err == nil { + t.Fatal("Expected closed connection") + } + err = s.Recv(m) + if err == nil { + t.Fatal("Expected closed connection") + } +}