mirror of
https://github.com/go-micro/go-micro.git
synced 2024-12-30 10:10:44 +02:00
MDNS registry fix for users on VPNs (#1759)
* filter out unsolicited responses * send to local ip in case * allow ip func to be passed in. add option for sending to 0.0.0.0
This commit is contained in:
parent
3e6ac73cfe
commit
7be4a67673
@ -252,7 +252,7 @@ func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
srv, err := mdns.NewServer(&mdns.Config{Zone: s})
|
srv, err := mdns.NewServer(&mdns.Config{Zone: s, LocalhostChecking: true})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
gerr = err
|
gerr = err
|
||||||
continue
|
continue
|
||||||
|
@ -34,6 +34,7 @@ type ServiceEntry struct {
|
|||||||
|
|
||||||
// complete is used to check if we have all the info we need
|
// complete is used to check if we have all the info we need
|
||||||
func (s *ServiceEntry) complete() bool {
|
func (s *ServiceEntry) complete() bool {
|
||||||
|
|
||||||
return (len(s.AddrV4) > 0 || len(s.AddrV6) > 0 || len(s.Addr) > 0) && s.Port != 0 && s.hasTXT
|
return (len(s.AddrV4) > 0 || len(s.AddrV6) > 0 || len(s.Addr) > 0) && s.Port != 0 && s.hasTXT
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -347,15 +348,21 @@ func (c *client) query(params *QueryParam) error {
|
|||||||
select {
|
select {
|
||||||
case resp := <-msgCh:
|
case resp := <-msgCh:
|
||||||
inp := messageToEntry(resp, inprogress)
|
inp := messageToEntry(resp, inprogress)
|
||||||
|
|
||||||
if inp == nil {
|
if inp == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if len(resp.Question) == 0 || resp.Question[0].Name != m.Question[0].Name {
|
||||||
|
// discard anything which we've not asked for
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Check if this entry is complete
|
// Check if this entry is complete
|
||||||
if inp.complete() {
|
if inp.complete() {
|
||||||
if inp.sent {
|
if inp.sent {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
inp.sent = true
|
inp.sent = true
|
||||||
select {
|
select {
|
||||||
case params.Entries <- inp:
|
case params.Entries <- inp:
|
||||||
|
@ -2,13 +2,13 @@ package mdns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/micro/go-micro/v2/logger"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
@ -39,6 +39,10 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// GetMachineIP is a func which returns the outbound IP of this machine.
|
||||||
|
// Used by the server to determine whether to attempt send the response on a local address
|
||||||
|
type GetMachineIP func() net.IP
|
||||||
|
|
||||||
// Config is used to configure the mDNS server
|
// Config is used to configure the mDNS server
|
||||||
type Config struct {
|
type Config struct {
|
||||||
// Zone must be provided to support responding to queries
|
// Zone must be provided to support responding to queries
|
||||||
@ -51,9 +55,15 @@ type Config struct {
|
|||||||
|
|
||||||
// Port If it is not 0, replace the port 5353 with this port number.
|
// Port If it is not 0, replace the port 5353 with this port number.
|
||||||
Port int
|
Port int
|
||||||
|
|
||||||
|
// GetMachineIP is a function to return the IP of the local machine
|
||||||
|
GetMachineIP GetMachineIP
|
||||||
|
// LocalhostChecking if enabled asks the server to also send responses to 0.0.0.0 if the target IP
|
||||||
|
// is this host (as defined by GetMachineIP). Useful in case machine is on a VPN which blocks comms on non standard ports
|
||||||
|
LocalhostChecking bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// mDNS server is used to listen for mDNS queries and respond if we
|
// Server is an mDNS server used to listen for mDNS queries and respond if we
|
||||||
// have a matching local record
|
// have a matching local record
|
||||||
type Server struct {
|
type Server struct {
|
||||||
config *Config
|
config *Config
|
||||||
@ -65,6 +75,8 @@ type Server struct {
|
|||||||
shutdownCh chan struct{}
|
shutdownCh chan struct{}
|
||||||
shutdownLock sync.Mutex
|
shutdownLock sync.Mutex
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
|
|
||||||
|
outboundIP net.IP
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer is used to create a new mDNS server from a config
|
// NewServer is used to create a new mDNS server from a config
|
||||||
@ -118,11 +130,17 @@ func NewServer(config *Config) (*Server, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ipFunc := getOutboundIP
|
||||||
|
if config.GetMachineIP != nil {
|
||||||
|
ipFunc = config.GetMachineIP
|
||||||
|
}
|
||||||
|
|
||||||
s := &Server{
|
s := &Server{
|
||||||
config: config,
|
config: config,
|
||||||
ipv4List: ipv4List,
|
ipv4List: ipv4List,
|
||||||
ipv6List: ipv6List,
|
ipv6List: ipv6List,
|
||||||
shutdownCh: make(chan struct{}),
|
shutdownCh: make(chan struct{}),
|
||||||
|
outboundIP: ipFunc(),
|
||||||
}
|
}
|
||||||
|
|
||||||
go s.recv(s.ipv4List)
|
go s.recv(s.ipv4List)
|
||||||
@ -176,7 +194,7 @@ func (s *Server) recv(c *net.UDPConn) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if err := s.parsePacket(buf[:n], from); err != nil {
|
if err := s.parsePacket(buf[:n], from); err != nil {
|
||||||
log.Printf("[ERR] mdns: Failed to handle query: %v", err)
|
log.Errorf("[ERR] mdns: Failed to handle query: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -185,7 +203,7 @@ func (s *Server) recv(c *net.UDPConn) {
|
|||||||
func (s *Server) parsePacket(packet []byte, from net.Addr) error {
|
func (s *Server) parsePacket(packet []byte, from net.Addr) error {
|
||||||
var msg dns.Msg
|
var msg dns.Msg
|
||||||
if err := msg.Unpack(packet); err != nil {
|
if err := msg.Unpack(packet); err != nil {
|
||||||
log.Printf("[ERR] mdns: Failed to unpack packet: %v", err)
|
log.Errorf("[ERR] mdns: Failed to unpack packet: %v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// TODO: This is a bit of a hack
|
// TODO: This is a bit of a hack
|
||||||
@ -278,8 +296,8 @@ func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error {
|
|||||||
// caveats in the RFC), so set the Compress bit (part of the dns library
|
// caveats in the RFC), so set the Compress bit (part of the dns library
|
||||||
// API, not part of the DNS packet) to true.
|
// API, not part of the DNS packet) to true.
|
||||||
Compress: true,
|
Compress: true,
|
||||||
|
Question: query.Question,
|
||||||
Answer: answer,
|
Answer: answer,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -302,7 +320,6 @@ func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error {
|
|||||||
// both. The return values are DNS records for each transmission type.
|
// both. The return values are DNS records for each transmission type.
|
||||||
func (s *Server) handleQuestion(q dns.Question) (multicastRecs, unicastRecs []dns.RR) {
|
func (s *Server) handleQuestion(q dns.Question) (multicastRecs, unicastRecs []dns.RR) {
|
||||||
records := s.config.Zone.Records(q)
|
records := s.config.Zone.Records(q)
|
||||||
|
|
||||||
if len(records) == 0 {
|
if len(records) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@ -365,7 +382,7 @@ func (s *Server) probe() {
|
|||||||
|
|
||||||
for i := 0; i < 3; i++ {
|
for i := 0; i < 3; i++ {
|
||||||
if err := s.SendMulticast(q); err != nil {
|
if err := s.SendMulticast(q); err != nil {
|
||||||
log.Println("[ERR] mdns: failed to send probe:", err.Error())
|
log.Errorf("[ERR] mdns: failed to send probe:", err.Error())
|
||||||
}
|
}
|
||||||
time.Sleep(time.Duration(randomizer.Intn(250)) * time.Millisecond)
|
time.Sleep(time.Duration(randomizer.Intn(250)) * time.Millisecond)
|
||||||
}
|
}
|
||||||
@ -391,7 +408,7 @@ func (s *Server) probe() {
|
|||||||
timer := time.NewTimer(timeout)
|
timer := time.NewTimer(timeout)
|
||||||
for i := 0; i < 3; i++ {
|
for i := 0; i < 3; i++ {
|
||||||
if err := s.SendMulticast(resp); err != nil {
|
if err := s.SendMulticast(resp); err != nil {
|
||||||
log.Println("[ERR] mdns: failed to send announcement:", err.Error())
|
log.Errorf("[ERR] mdns: failed to send announcement:", err.Error())
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case <-timer.C:
|
case <-timer.C:
|
||||||
@ -404,7 +421,7 @@ func (s *Server) probe() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// multicastResponse us used to send a multicast response packet
|
// SendMulticast us used to send a multicast response packet
|
||||||
func (s *Server) SendMulticast(msg *dns.Msg) error {
|
func (s *Server) SendMulticast(msg *dns.Msg) error {
|
||||||
buf, err := msg.Pack()
|
buf, err := msg.Pack()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -430,13 +447,23 @@ func (s *Server) sendResponse(resp *dns.Msg, from net.Addr) error {
|
|||||||
|
|
||||||
// Determine the socket to send from
|
// Determine the socket to send from
|
||||||
addr := from.(*net.UDPAddr)
|
addr := from.(*net.UDPAddr)
|
||||||
if addr.IP.To4() != nil {
|
conn := s.ipv4List
|
||||||
_, err = s.ipv4List.WriteToUDP(buf, addr)
|
backupTarget := net.IPv4zero
|
||||||
return err
|
|
||||||
} else {
|
if addr.IP.To4() == nil {
|
||||||
_, err = s.ipv6List.WriteToUDP(buf, addr)
|
conn = s.ipv6List
|
||||||
return err
|
backupTarget = net.IPv6zero
|
||||||
}
|
}
|
||||||
|
_, err = conn.WriteToUDP(buf, addr)
|
||||||
|
// If the address we're responding to is this machine then we can also attempt sending on 0.0.0.0
|
||||||
|
// This covers the case where this machine is using a VPN and certain ports are blocked so the response never gets there
|
||||||
|
// Sending two responses is OK
|
||||||
|
if s.config.LocalhostChecking && addr.IP.Equal(s.outboundIP) {
|
||||||
|
// ignore any errors, this is best efforts
|
||||||
|
conn.WriteToUDP(buf, &net.UDPAddr{IP: backupTarget, Port: addr.Port})
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) unregister() error {
|
func (s *Server) unregister() error {
|
||||||
@ -474,3 +501,17 @@ func setCustomPort(port int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getOutboundIP returns the IP address of this machine as seen when dialling out
|
||||||
|
func getOutboundIP() net.IP {
|
||||||
|
conn, err := net.Dial("udp", "8.8.8.8:80")
|
||||||
|
if err != nil {
|
||||||
|
// no net connectivity maybe so fallback
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
localAddr := conn.LocalAddr().(*net.UDPAddr)
|
||||||
|
|
||||||
|
return localAddr.IP
|
||||||
|
}
|
||||||
|
@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
func TestServer_StartStop(t *testing.T) {
|
func TestServer_StartStop(t *testing.T) {
|
||||||
s := makeService(t)
|
s := makeService(t)
|
||||||
serv, err := NewServer(&Config{Zone: s})
|
serv, err := NewServer(&Config{Zone: s, LocalhostChecking: true})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
@ -15,7 +15,7 @@ func TestServer_StartStop(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_Lookup(t *testing.T) {
|
func TestServer_Lookup(t *testing.T) {
|
||||||
serv, err := NewServer(&Config{Zone: makeServiceWithServiceName(t, "_foobar._tcp")})
|
serv, err := NewServer(&Config{Zone: makeServiceWithServiceName(t, "_foobar._tcp"), LocalhostChecking: true})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user