From e688ab0a452f1d19560ece5015bc46488abe3080 Mon Sep 17 00:00:00 2001 From: Vasiliy Tolstov Date: Wed, 17 Jul 2019 10:38:50 +0300 Subject: [PATCH] fix ipv6 addr parsing and using Signed-off-by: Vasiliy Tolstov --- broker/http_broker.go | 19 +++++++++----- client/grpc/grpc_test.go | 9 +------ server/grpc/grpc.go | 44 +++++++++++++++++++------------ server/rpc_server.go | 45 ++++++++++++++++++++------------ transport/grpc/grpc_test.go | 8 +++--- transport/http_transport_test.go | 8 +++--- util/addr/addr.go | 13 +++++---- util/net/net.go | 26 ++++++++++-------- 8 files changed, 104 insertions(+), 68 deletions(-) diff --git a/broker/http_broker.go b/broker/http_broker.go index d28eeb2c..db4d78e7 100644 --- a/broker/http_broker.go +++ b/broker/http_broker.go @@ -13,8 +13,6 @@ import ( "net/http" "net/url" "runtime" - "strconv" - "strings" "sync" "time" @@ -614,18 +612,27 @@ func (h *httpBroker) Publish(topic string, msg *Message, opts ...PublishOption) } func (h *httpBroker) Subscribe(topic string, handler Handler, opts ...SubscribeOption) (Subscriber, error) { + var err error + var host, port string options := NewSubscribeOptions(opts...) // parse address for host, port - parts := strings.Split(h.Address(), ":") - host := strings.Join(parts[:len(parts)-1], ":") - port, _ := strconv.Atoi(parts[len(parts)-1]) + host, port, err = net.SplitHostPort(h.Address()) + if err != nil { + return nil, err + } addr, err := maddr.Extract(host) if err != nil { return nil, err } + // ipv6 addr + if addr == "::" { + // ipv6 addr + addr = fmt.Sprintf("[%s]", addr) + } + // create unique id id := h.id + "." + uuid.New().String() @@ -638,7 +645,7 @@ func (h *httpBroker) Subscribe(topic string, handler Handler, opts ...SubscribeO // register service node := ®istry.Node{ Id: id, - Address: fmt.Sprintf("%s:%d", addr, port), + Address: fmt.Sprintf("%s:%s", addr, port), Metadata: map[string]string{ "secure": fmt.Sprintf("%t", secure), }, diff --git a/client/grpc/grpc_test.go b/client/grpc/grpc_test.go index 3652ec6b..5696a267 100644 --- a/client/grpc/grpc_test.go +++ b/client/grpc/grpc_test.go @@ -2,10 +2,7 @@ package grpc import ( "context" - "fmt" "net" - "strconv" - "strings" "testing" "github.com/micro/go-micro/client" @@ -37,10 +34,6 @@ func TestGRPCClient(t *testing.T) { go s.Serve(l) defer s.Stop() - parts := strings.Split(l.Addr().String(), ":") - port, _ := strconv.Atoi(parts[len(parts)-1]) - addr := strings.Join(parts[:len(parts)-1], ":") - // create mock registry r := memory.NewRegistry() @@ -51,7 +44,7 @@ func TestGRPCClient(t *testing.T) { Nodes: []*registry.Node{ ®istry.Node{ Id: "test-1", - Address: fmt.Sprintf("%s:%d", addr, port), + Address: l.Addr().String(), }, }, }) diff --git a/server/grpc/grpc.go b/server/grpc/grpc.go index 8d2a2d8f..b3473efc 100644 --- a/server/grpc/grpc.go +++ b/server/grpc/grpc.go @@ -504,10 +504,11 @@ func (g *grpcServer) Subscribe(sb server.Subscriber) error { } func (g *grpcServer) Register() error { + var err error + var advt, host, port string + // parse address for host, port config := g.opts - var advt, host string - var port int // check the advertise address first // if it exists then use it, otherwise @@ -518,12 +519,17 @@ func (g *grpcServer) Register() error { advt = config.Address } - parts := strings.Split(advt, ":") - if len(parts) > 1 { - host = strings.Join(parts[:len(parts)-1], ":") - port, _ = strconv.Atoi(parts[len(parts)-1]) + if idx := strings.Count(advt, ":"); idx > 1 { + // ipv6 address in format [host]:port or ipv4 host:port + host, port, err = net.SplitHostPort(advt) + if err != nil { + return err + } + if host == "::" { + host = fmt.Sprintf("[%s]", host) + } } else { - host = parts[0] + host = advt } addr, err := addr.Extract(host) @@ -534,7 +540,7 @@ func (g *grpcServer) Register() error { // register service node := ®istry.Node{ Id: config.Name + "-" + config.Id, - Address: fmt.Sprintf("%s:%d", addr, port), + Address: fmt.Sprintf("%s:%s", addr, port), Metadata: config.Metadata, } @@ -629,9 +635,10 @@ func (g *grpcServer) Register() error { } func (g *grpcServer) Deregister() error { + var err error + var advt, host, port string + config := g.opts - var advt, host string - var port int // check the advertise address first // if it exists then use it, otherwise @@ -642,12 +649,17 @@ func (g *grpcServer) Deregister() error { advt = config.Address } - parts := strings.Split(advt, ":") - if len(parts) > 1 { - host = strings.Join(parts[:len(parts)-1], ":") - port, _ = strconv.Atoi(parts[len(parts)-1]) + if idx := strings.Count(advt, ":"); idx > 1 { + // ipv6 address in format [host]:port or ipv4 host:port + host, port, err = net.SplitHostPort(advt) + if err != nil { + return err + } + if host == "::" { + host = fmt.Sprintf("[%s]", host) + } } else { - host = parts[0] + host = advt } addr, err := addr.Extract(host) @@ -657,7 +669,7 @@ func (g *grpcServer) Deregister() error { node := ®istry.Node{ Id: config.Name + "-" + config.Id, - Address: fmt.Sprintf("%s:%d", addr, port), + Address: fmt.Sprintf("%s:%s", addr, port), } service := ®istry.Service{ diff --git a/server/rpc_server.go b/server/rpc_server.go index 533def1e..392f64b1 100644 --- a/server/rpc_server.go +++ b/server/rpc_server.go @@ -3,6 +3,7 @@ package server import ( "context" "fmt" + "net" "runtime/debug" "sort" "strconv" @@ -277,10 +278,11 @@ func (s *rpcServer) Subscribe(sb Subscriber) error { } func (s *rpcServer) Register() error { + var err error + var advt, host, port string + // parse address for host, port config := s.Options() - var advt, host string - var port int // check the advertise address first // if it exists then use it, otherwise @@ -291,12 +293,17 @@ func (s *rpcServer) Register() error { advt = config.Address } - parts := strings.Split(advt, ":") - if len(parts) > 1 { - host = strings.Join(parts[:len(parts)-1], ":") - port, _ = strconv.Atoi(parts[len(parts)-1]) + if idx := strings.Count(advt, ":"); idx > 1 { + // ipv6 address in format [host]:port or ipv4 host:port + host, port, err = net.SplitHostPort(advt) + if err != nil { + return err + } + if host == "::" { + host = fmt.Sprintf("[%s]", host) + } } else { - host = parts[0] + host = advt } addr, err := addr.Extract(host) @@ -313,7 +320,7 @@ func (s *rpcServer) Register() error { // register service node := ®istry.Node{ Id: config.Name + "-" + config.Id, - Address: fmt.Sprintf("%s:%d", addr, port), + Address: fmt.Sprintf("%s:%s", addr, port), Metadata: md, } @@ -413,9 +420,10 @@ func (s *rpcServer) Register() error { } func (s *rpcServer) Deregister() error { + var err error + var advt, host, port string + config := s.Options() - var advt, host string - var port int // check the advertise address first // if it exists then use it, otherwise @@ -426,12 +434,17 @@ func (s *rpcServer) Deregister() error { advt = config.Address } - parts := strings.Split(advt, ":") - if len(parts) > 1 { - host = strings.Join(parts[:len(parts)-1], ":") - port, _ = strconv.Atoi(parts[len(parts)-1]) + if idx := strings.Count(advt, ":"); idx > 1 { + // ipv6 address in format [host]:port or ipv4 host:port + host, port, err = net.SplitHostPort(advt) + if err != nil { + return err + } + if host == "::" { + host = fmt.Sprintf("[%s]", host) + } } else { - host = parts[0] + host = advt } addr, err := addr.Extract(host) @@ -441,7 +454,7 @@ func (s *rpcServer) Deregister() error { node := ®istry.Node{ Id: config.Name + "-" + config.Id, - Address: fmt.Sprintf("%s:%d", addr, port), + Address: fmt.Sprintf("%s:%s", addr, port), } service := ®istry.Service{ diff --git a/transport/grpc/grpc_test.go b/transport/grpc/grpc_test.go index d4e82346..b1e46ba5 100644 --- a/transport/grpc/grpc_test.go +++ b/transport/grpc/grpc_test.go @@ -1,15 +1,17 @@ package grpc import ( - "strings" + "net" "testing" "github.com/micro/go-micro/transport" ) func expectedPort(t *testing.T, expected string, lsn transport.Listener) { - parts := strings.Split(lsn.Addr(), ":") - port := parts[len(parts)-1] + _, port, err := net.SplitHostPort(lsn.Addr()) + if err != nil { + t.Errorf("Expected address to be `%s`, got error: %v", expected, err) + } if port != expected { lsn.Close() diff --git a/transport/http_transport_test.go b/transport/http_transport_test.go index fac80716..cbbc8658 100644 --- a/transport/http_transport_test.go +++ b/transport/http_transport_test.go @@ -2,14 +2,16 @@ package transport import ( "io" - "strings" + "net" "testing" "time" ) func expectedPort(t *testing.T, expected string, lsn Listener) { - parts := strings.Split(lsn.Addr(), ":") - port := parts[len(parts)-1] + _, port, err := net.SplitHostPort(lsn.Addr()) + if err != nil { + t.Errorf("Expected address to be `%s`, got error: %v", expected, err) + } if port != expected { lsn.Close() diff --git a/util/addr/addr.go b/util/addr/addr.go index ab3acca1..b2874533 100644 --- a/util/addr/addr.go +++ b/util/addr/addr.go @@ -30,7 +30,7 @@ func isPrivateIP(ipAddr string) bool { // Extract returns a real ip func Extract(addr string) (string, error) { // if addr specified then its returned - if len(addr) > 0 && (addr != "0.0.0.0" && addr != "[::]") { + if len(addr) > 0 && (addr != "0.0.0.0" && addr != "[::]" && addr != "::") { return addr, nil } @@ -113,10 +113,13 @@ func IPs() []string { continue } - ip = ip.To4() - if ip == nil { - continue - } + // dont skip ipv6 addrs + /* + ip = ip.To4() + if ip == nil { + continue + } + */ ipAddrs = append(ipAddrs, ip.String()) } diff --git a/util/net/net.go b/util/net/net.go index b092068f..f9726f6b 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -11,39 +11,43 @@ import ( // Listen takes addr:portmin-portmax and binds to the first available port // Example: Listen("localhost:5000-6000", fn) func Listen(addr string, fn func(string) (net.Listener, error)) (net.Listener, error) { - // host:port || host:min-max - parts := strings.Split(addr, ":") - // - if len(parts) < 2 { + if strings.Count(addr, ":") == 1 && strings.Count(addr, "-") == 0 { return fn(addr) } + // host:port || host:min-max + host, ports, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + if host == "::" { + host = fmt.Sprintf("[%s]", host) + } + // try to extract port range - ports := strings.Split(parts[len(parts)-1], "-") + prange := strings.Split(ports, "-") // single port - if len(ports) < 2 { + if len(prange) < 2 { return fn(addr) } // we have a port range // extract min port - min, err := strconv.Atoi(ports[0]) + min, err := strconv.Atoi(prange[0]) if err != nil { return nil, errors.New("unable to extract port range") } // extract max port - max, err := strconv.Atoi(ports[1]) + max, err := strconv.Atoi(prange[1]) if err != nil { return nil, errors.New("unable to extract port range") } - // set host - host := parts[:len(parts)-1] - // range the ports for port := min; port <= max; port++ { // try bind to host:port