diff --git a/go.mod b/go.mod index cf373f7a..8a445a11 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.13 require ( github.com/BurntSushi/toml v0.3.1 - github.com/beevik/ntp v0.2.0 github.com/bitly/go-simplejson v0.5.0 github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect github.com/bwmarrin/discordgo v0.20.2 @@ -47,7 +46,6 @@ require ( github.com/lucas-clemente/quic-go v0.14.1 github.com/mholt/certmagic v0.9.3 github.com/micro/cli/v2 v2.1.2 - github.com/micro/mdns v0.3.0 github.com/miekg/dns v1.1.27 github.com/mitchellh/hashstructure v1.0.0 github.com/nats-io/nats-server/v2 v2.1.6 diff --git a/go.sum b/go.sum index 848d8da2..db9453b6 100644 --- a/go.sum +++ b/go.sum @@ -53,8 +53,6 @@ github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPd github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/aws/aws-sdk-go v1.23.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= github.com/baiyubin/aliyun-sts-go-sdk v0.0.0-20180326062324-cfa1a18b161f/go.mod h1:AuiFmCCPBSrqvVMvuqFuk0qogytodnVFVSN5CeJB8Gc= -github.com/beevik/ntp v0.2.0 h1:sGsd+kAXzT0bfVfzJfce04g+dSRfrs+tbQW8lweuYgw= -github.com/beevik/ntp v0.2.0/go.mod h1:hIHWr+l3+/clUnF44zdK+CWW7fO8dR5cIylAQ76NRpg= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -290,9 +288,6 @@ github.com/mholt/certmagic v0.9.3 h1:RmzuNJ5mpFplDbyS41z+gGgE/py24IX6m0nHZ0yNTQU github.com/mholt/certmagic v0.9.3/go.mod h1:nu8jbsbtwK4205EDH/ZUMTKsfYpJA1Q7MKXHfgTihNw= github.com/micro/cli/v2 v2.1.2 h1:43J1lChg/rZCC1rvdqZNFSQDrGT7qfMrtp6/ztpIkEM= github.com/micro/cli/v2 v2.1.2/go.mod h1:EguNh6DAoWKm9nmk+k/Rg0H3lQnDxqzu5x5srOtGtYg= -github.com/micro/mdns v0.3.0 h1:bYycYe+98AXR3s8Nq5qvt6C573uFTDPIYzJemWON0QE= -github.com/micro/mdns v0.3.0/go.mod h1:KJ0dW7KmicXU2BV++qkLlmHYcVv7/hHnbtguSWt9Aoc= -github.com/miekg/dns v1.1.3/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/miekg/dns v1.1.15/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/miekg/dns v1.1.27 h1:aEH/kqUzUxGJ/UHcEKdJY+ugH6WEzsEBBSPa8zuy1aM= github.com/miekg/dns v1.1.27/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM= @@ -314,7 +309,6 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW github.com/namedotcom/go v0.0.0-20180403034216-08470befbe04/go.mod h1:5sN+Lt1CaY4wsPvgQH/jsuJi4XO2ssZbdsIizr4CVC8= github.com/nats-io/jwt v0.3.2 h1:+RB5hMpXUUA2dfxuhBTEkMOrYmM+gKIZYS1KjSostMI= github.com/nats-io/jwt v0.3.2/go.mod h1:/euKqTS1ZD+zzjYrY7pseZrTtWQSjujC7xjPc8wL6eU= -github.com/nats-io/nats-server v1.4.1 h1:Ul1oSOGNV/L8kjr4v6l2f9Yet6WY+LevH1/7cRZ/qyA= github.com/nats-io/nats-server/v2 v2.1.6 h1:qAaHZaS8pRRNQLFaiBA1rq5WynyEGp9DFgmMfoaiXGY= github.com/nats-io/nats-server/v2 v2.1.6/go.mod h1:BL1NOtaBQ5/y97djERRVWNouMW7GT3gxnmbE/eC8u8A= github.com/nats-io/nats.go v1.9.2 h1:oDeERm3NcZVrPpdR/JpGdWHMv3oJ8yY30YwxKq+DU2s= @@ -454,7 +448,6 @@ go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= golang.org/x/crypto v0.0.0-20180621125126-a49355c7e3f8/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20190130090550-b01c7a725664/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190219172222-a4c6cb3142f2/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -465,8 +458,6 @@ golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190927123631-a832865fa7ad/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d h1:1ZiEyfaQIg3Qh0EoqpwAakHVhecoE5wlSg5GjnafJGw= -golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 h1:3zb4D3T4G8jdExgVU/95+vQXfpEPiMdCaZgmGVxjNHM= golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -530,7 +521,6 @@ golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190129075346-302c3dd5f1cc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190209173611-3b5209105503/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190221075227-b4e8571b14e0/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/registry/mdns_registry.go b/registry/mdns_registry.go index 7e614840..dfc8e6f0 100644 --- a/registry/mdns_registry.go +++ b/registry/mdns_registry.go @@ -17,7 +17,7 @@ import ( "github.com/google/uuid" "github.com/micro/go-micro/v2/logger" - "github.com/micro/mdns" + "github.com/micro/go-micro/v2/util/mdns" ) var ( diff --git a/registry/mdns_watcher.go b/registry/mdns_watcher.go index 402811b9..e0ef4a48 100644 --- a/registry/mdns_watcher.go +++ b/registry/mdns_watcher.go @@ -4,7 +4,7 @@ import ( "fmt" "strings" - "github.com/micro/mdns" + "github.com/micro/go-micro/v2/util/mdns" ) type mdnsWatcher struct { diff --git a/util/mdns/.gitignore b/util/mdns/.gitignore new file mode 100644 index 00000000..83656241 --- /dev/null +++ b/util/mdns/.gitignore @@ -0,0 +1,23 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test diff --git a/util/mdns/client.go b/util/mdns/client.go new file mode 100644 index 00000000..176ebac4 --- /dev/null +++ b/util/mdns/client.go @@ -0,0 +1,501 @@ +package mdns + +import ( + "context" + "fmt" + "log" + "net" + "strings" + "sync" + "time" + + "github.com/miekg/dns" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +// ServiceEntry is returned after we query for a service +type ServiceEntry struct { + Name string + Host string + AddrV4 net.IP + AddrV6 net.IP + Port int + Info string + InfoFields []string + TTL int + Type uint16 + + Addr net.IP // @Deprecated + + hasTXT bool + sent bool +} + +// complete is used to check if we have all the info we need +func (s *ServiceEntry) complete() bool { + return (s.AddrV4 != nil || s.AddrV6 != nil || s.Addr != nil) && s.Port != 0 && s.hasTXT +} + +// QueryParam is used to customize how a Lookup is performed +type QueryParam struct { + Service string // Service to lookup + Domain string // Lookup domain, default "local" + Type uint16 // Lookup type, defaults to dns.TypePTR + Context context.Context // Context + Timeout time.Duration // Lookup timeout, default 1 second. Ignored if Context is provided + Interface *net.Interface // Multicast interface to use + Entries chan<- *ServiceEntry // Entries Channel + WantUnicastResponse bool // Unicast response desired, as per 5.4 in RFC +} + +// DefaultParams is used to return a default set of QueryParam's +func DefaultParams(service string) *QueryParam { + return &QueryParam{ + Service: service, + Domain: "local", + Timeout: time.Second, + Entries: make(chan *ServiceEntry), + WantUnicastResponse: false, // TODO(reddaly): Change this default. + } +} + +// Query looks up a given service, in a domain, waiting at most +// for a timeout before finishing the query. The results are streamed +// to a channel. Sends will not block, so clients should make sure to +// either read or buffer. +func Query(params *QueryParam) error { + // Create a new client + client, err := newClient() + if err != nil { + return err + } + defer client.Close() + + // Set the multicast interface + if params.Interface != nil { + if err := client.setInterface(params.Interface, false); err != nil { + return err + } + } + + // Ensure defaults are set + if params.Domain == "" { + params.Domain = "local" + } + + if params.Context == nil { + if params.Timeout == 0 { + params.Timeout = time.Second + } + params.Context, _ = context.WithTimeout(context.Background(), params.Timeout) + if err != nil { + return err + } + } + + // Run the query + return client.query(params) +} + +// Listen listens indefinitely for multicast updates +func Listen(entries chan<- *ServiceEntry, exit chan struct{}) error { + // Create a new client + client, err := newClient() + if err != nil { + return err + } + defer client.Close() + + client.setInterface(nil, true) + + // Start listening for response packets + msgCh := make(chan *dns.Msg, 32) + + go client.recv(client.ipv4UnicastConn, msgCh) + go client.recv(client.ipv6UnicastConn, msgCh) + go client.recv(client.ipv4MulticastConn, msgCh) + go client.recv(client.ipv6MulticastConn, msgCh) + + ip := make(map[string]*ServiceEntry) + + for { + select { + case <-exit: + return nil + case <-client.closedCh: + return nil + case m := <-msgCh: + e := messageToEntry(m, ip) + if e == nil { + continue + } + + // Check if this entry is complete + if e.complete() { + if e.sent { + continue + } + e.sent = true + entries <- e + ip = make(map[string]*ServiceEntry) + } else { + // Fire off a node specific query + m := new(dns.Msg) + m.SetQuestion(e.Name, dns.TypePTR) + m.RecursionDesired = false + if err := client.sendQuery(m); err != nil { + log.Printf("[ERR] mdns: Failed to query instance %s: %v", e.Name, err) + } + } + } + } + + return nil +} + +// Lookup is the same as Query, however it uses all the default parameters +func Lookup(service string, entries chan<- *ServiceEntry) error { + params := DefaultParams(service) + params.Entries = entries + return Query(params) +} + +// Client provides a query interface that can be used to +// search for service providers using mDNS +type client struct { + ipv4UnicastConn *net.UDPConn + ipv6UnicastConn *net.UDPConn + + ipv4MulticastConn *net.UDPConn + ipv6MulticastConn *net.UDPConn + + closed bool + closedCh chan struct{} // TODO(reddaly): This doesn't appear to be used. + closeLock sync.Mutex +} + +// NewClient creates a new mdns Client that can be used to query +// for records +func newClient() (*client, error) { + // TODO(reddaly): At least attempt to bind to the port required in the spec. + // Create a IPv4 listener + uconn4, err4 := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + uconn6, err6 := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) + if err4 != nil && err6 != nil { + log.Printf("[ERR] mdns: Failed to bind to udp port: %v %v", err4, err6) + } + + if uconn4 == nil && uconn6 == nil { + return nil, fmt.Errorf("failed to bind to any unicast udp port") + } + + if uconn4 == nil { + uconn4 = &net.UDPConn{} + } + + if uconn6 == nil { + uconn6 = &net.UDPConn{} + } + + mconn4, err4 := net.ListenUDP("udp4", mdnsWildcardAddrIPv4) + mconn6, err6 := net.ListenUDP("udp6", mdnsWildcardAddrIPv6) + if err4 != nil && err6 != nil { + log.Printf("[ERR] mdns: Failed to bind to udp port: %v %v", err4, err6) + } + + if mconn4 == nil && mconn6 == nil { + return nil, fmt.Errorf("failed to bind to any multicast udp port") + } + + if mconn4 == nil { + mconn4 = &net.UDPConn{} + } + + if mconn6 == nil { + mconn6 = &net.UDPConn{} + } + + p1 := ipv4.NewPacketConn(mconn4) + p2 := ipv6.NewPacketConn(mconn6) + p1.SetMulticastLoopback(true) + p2.SetMulticastLoopback(true) + + ifaces, err := net.Interfaces() + if err != nil { + return nil, err + } + + var errCount1, errCount2 int + + for _, iface := range ifaces { + if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { + errCount1++ + } + if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { + errCount2++ + } + } + + if len(ifaces) == errCount1 && len(ifaces) == errCount2 { + return nil, fmt.Errorf("Failed to join multicast group on all interfaces!") + } + + c := &client{ + ipv4MulticastConn: mconn4, + ipv6MulticastConn: mconn6, + ipv4UnicastConn: uconn4, + ipv6UnicastConn: uconn6, + closedCh: make(chan struct{}), + } + return c, nil +} + +// Close is used to cleanup the client +func (c *client) Close() error { + c.closeLock.Lock() + defer c.closeLock.Unlock() + + if c.closed { + return nil + } + c.closed = true + + close(c.closedCh) + + if c.ipv4UnicastConn != nil { + c.ipv4UnicastConn.Close() + } + if c.ipv6UnicastConn != nil { + c.ipv6UnicastConn.Close() + } + if c.ipv4MulticastConn != nil { + c.ipv4MulticastConn.Close() + } + if c.ipv6MulticastConn != nil { + c.ipv6MulticastConn.Close() + } + + return nil +} + +// setInterface is used to set the query interface, uses sytem +// default if not provided +func (c *client) setInterface(iface *net.Interface, loopback bool) error { + p := ipv4.NewPacketConn(c.ipv4UnicastConn) + if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { + return err + } + p2 := ipv6.NewPacketConn(c.ipv6UnicastConn) + if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { + return err + } + p = ipv4.NewPacketConn(c.ipv4MulticastConn) + if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { + return err + } + p2 = ipv6.NewPacketConn(c.ipv6MulticastConn) + if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { + return err + } + + if loopback { + p.SetMulticastLoopback(true) + p2.SetMulticastLoopback(true) + } + + return nil +} + +// query is used to perform a lookup and stream results +func (c *client) query(params *QueryParam) error { + // Create the service name + serviceAddr := fmt.Sprintf("%s.%s.", trimDot(params.Service), trimDot(params.Domain)) + + // Start listening for response packets + msgCh := make(chan *dns.Msg, 32) + go c.recv(c.ipv4UnicastConn, msgCh) + go c.recv(c.ipv6UnicastConn, msgCh) + go c.recv(c.ipv4MulticastConn, msgCh) + go c.recv(c.ipv6MulticastConn, msgCh) + + // Send the query + m := new(dns.Msg) + if params.Type == dns.TypeNone { + m.SetQuestion(serviceAddr, dns.TypePTR) + } else { + m.SetQuestion(serviceAddr, params.Type) + } + // RFC 6762, section 18.12. Repurposing of Top Bit of qclass in Question + // Section + // + // In the Question Section of a Multicast DNS query, the top bit of the qclass + // field is used to indicate that unicast responses are preferred for this + // particular question. (See Section 5.4.) + if params.WantUnicastResponse { + m.Question[0].Qclass |= 1 << 15 + } + m.RecursionDesired = false + if err := c.sendQuery(m); err != nil { + return err + } + + // Map the in-progress responses + inprogress := make(map[string]*ServiceEntry) + + for { + select { + case resp := <-msgCh: + inp := messageToEntry(resp, inprogress) + if inp == nil { + continue + } + + // Check if this entry is complete + if inp.complete() { + if inp.sent { + continue + } + inp.sent = true + select { + case params.Entries <- inp: + case <-params.Context.Done(): + return nil + } + } else { + // Fire off a node specific query + m := new(dns.Msg) + m.SetQuestion(inp.Name, inp.Type) + m.RecursionDesired = false + if err := c.sendQuery(m); err != nil { + log.Printf("[ERR] mdns: Failed to query instance %s: %v", inp.Name, err) + } + } + case <-params.Context.Done(): + return nil + } + } +} + +// sendQuery is used to multicast a query out +func (c *client) sendQuery(q *dns.Msg) error { + buf, err := q.Pack() + if err != nil { + return err + } + if c.ipv4UnicastConn != nil { + c.ipv4UnicastConn.WriteToUDP(buf, ipv4Addr) + } + if c.ipv6UnicastConn != nil { + c.ipv6UnicastConn.WriteToUDP(buf, ipv6Addr) + } + return nil +} + +// recv is used to receive until we get a shutdown +func (c *client) recv(l *net.UDPConn, msgCh chan *dns.Msg) { + if l == nil { + return + } + buf := make([]byte, 65536) + for { + c.closeLock.Lock() + if c.closed { + c.closeLock.Unlock() + return + } + c.closeLock.Unlock() + n, err := l.Read(buf) + if err != nil { + continue + } + msg := new(dns.Msg) + if err := msg.Unpack(buf[:n]); err != nil { + continue + } + select { + case msgCh <- msg: + case <-c.closedCh: + return + } + } +} + +// ensureName is used to ensure the named node is in progress +func ensureName(inprogress map[string]*ServiceEntry, name string, typ uint16) *ServiceEntry { + if inp, ok := inprogress[name]; ok { + return inp + } + inp := &ServiceEntry{ + Name: name, + Type: typ, + } + inprogress[name] = inp + return inp +} + +// alias is used to setup an alias between two entries +func alias(inprogress map[string]*ServiceEntry, src, dst string, typ uint16) { + srcEntry := ensureName(inprogress, src, typ) + inprogress[dst] = srcEntry +} + +func messageToEntry(m *dns.Msg, inprogress map[string]*ServiceEntry) *ServiceEntry { + var inp *ServiceEntry + + for _, answer := range append(m.Answer, m.Extra...) { + // TODO(reddaly): Check that response corresponds to serviceAddr? + switch rr := answer.(type) { + case *dns.PTR: + // Create new entry for this + inp = ensureName(inprogress, rr.Ptr, rr.Hdr.Rrtype) + if inp.complete() { + continue + } + case *dns.SRV: + // Check for a target mismatch + if rr.Target != rr.Hdr.Name { + alias(inprogress, rr.Hdr.Name, rr.Target, rr.Hdr.Rrtype) + } + + // Get the port + inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) + if inp.complete() { + continue + } + inp.Host = rr.Target + inp.Port = int(rr.Port) + case *dns.TXT: + // Pull out the txt + inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) + if inp.complete() { + continue + } + inp.Info = strings.Join(rr.Txt, "|") + inp.InfoFields = rr.Txt + inp.hasTXT = true + case *dns.A: + // Pull out the IP + inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) + if inp.complete() { + continue + } + inp.Addr = rr.A // @Deprecated + inp.AddrV4 = rr.A + case *dns.AAAA: + // Pull out the IP + inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) + if inp.complete() { + continue + } + inp.Addr = rr.AAAA // @Deprecated + inp.AddrV6 = rr.AAAA + } + + if inp != nil { + inp.TTL = int(answer.Header().Ttl) + } + } + + return inp +} diff --git a/util/mdns/dns_sd.go b/util/mdns/dns_sd.go new file mode 100644 index 00000000..18444c34 --- /dev/null +++ b/util/mdns/dns_sd.go @@ -0,0 +1,84 @@ +package mdns + +import "github.com/miekg/dns" + +// DNSSDService is a service that complies with the DNS-SD (RFC 6762) and MDNS +// (RFC 6762) specs for local, multicast-DNS-based discovery. +// +// DNSSDService implements the Zone interface and wraps an MDNSService instance. +// To deploy an mDNS service that is compliant with DNS-SD, it's recommended to +// register only the wrapped instance with the server. +// +// Example usage: +// service := &mdns.DNSSDService{ +// MDNSService: &mdns.MDNSService{ +// Instance: "My Foobar Service", +// Service: "_foobar._tcp", +// Port: 8000, +// } +// } +// server, err := mdns.NewServer(&mdns.Config{Zone: service}) +// if err != nil { +// log.Fatalf("Error creating server: %v", err) +// } +// defer server.Shutdown() +type DNSSDService struct { + MDNSService *MDNSService +} + +// Records returns DNS records in response to a DNS question. +// +// This function returns the DNS response of the underlying MDNSService +// instance. It also returns a PTR record for a request for " +// _services._dns-sd._udp.", as described in section 9 of RFC 6763 +// ("Service Type Enumeration"), to allow browsing of the underlying MDNSService +// instance. +func (s *DNSSDService) Records(q dns.Question) []dns.RR { + var recs []dns.RR + if q.Name == "_services._dns-sd._udp."+s.MDNSService.Domain+"." { + recs = s.dnssdMetaQueryRecords(q) + } + return append(recs, s.MDNSService.Records(q)...) +} + +// dnssdMetaQueryRecords returns the DNS records in response to a "meta-query" +// issued to browse for DNS-SD services, as per section 9. of RFC6763. +// +// A meta-query has a name of the form "_services._dns-sd._udp." where +// Domain is a fully-qualified domain, such as "local." +func (s *DNSSDService) dnssdMetaQueryRecords(q dns.Question) []dns.RR { + // Intended behavior, as described in the RFC: + // ...it may be useful for network administrators to find the list of + // advertised service types on the network, even if those Service Names + // are just opaque identifiers and not particularly informative in + // isolation. + // + // For this purpose, a special meta-query is defined. A DNS query for PTR + // records with the name "_services._dns-sd._udp." yields a set of + // PTR records, where the rdata of each PTR record is the two-abel + // name, plus the same domain, e.g., "_http._tcp.". + // Including the domain in the PTR rdata allows for slightly better name + // compression in Unicast DNS responses, but only the first two labels are + // relevant for the purposes of service type enumeration. These two-label + // service types can then be used to construct subsequent Service Instance + // Enumeration PTR queries, in this or others, to discover + // instances of that service type. + return []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + Ttl: defaultTTL, + }, + Ptr: s.MDNSService.serviceAddr, + }, + } +} + +// Announcement returns DNS records that should be broadcast during the initial +// availability of the service, as described in section 8.3 of RFC 6762. +// TODO(reddaly): Add this when Announcement is added to the mdns.Zone interface. +//func (s *DNSSDService) Announcement() []dns.RR { +// return s.MDNSService.Announcement() +//} diff --git a/util/mdns/dns_sd_test.go b/util/mdns/dns_sd_test.go new file mode 100644 index 00000000..d973fd7e --- /dev/null +++ b/util/mdns/dns_sd_test.go @@ -0,0 +1,68 @@ +package mdns + +import ( + "reflect" + "testing" +) +import "github.com/miekg/dns" + +type mockMDNSService struct{} + +func (s *mockMDNSService) Records(q dns.Question) []dns.RR { + return []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{ + Name: "fakerecord", + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + Ttl: 42, + }, + Ptr: "fake.local.", + }, + } +} + +func (s *mockMDNSService) Announcement() []dns.RR { + return []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{ + Name: "fakeannounce", + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + Ttl: 42, + }, + Ptr: "fake.local.", + }, + } +} + +func TestDNSSDServiceRecords(t *testing.T) { + s := &DNSSDService{ + MDNSService: &MDNSService{ + serviceAddr: "_foobar._tcp.local.", + Domain: "local", + }, + } + q := dns.Question{ + Name: "_services._dns-sd._udp.local.", + Qtype: dns.TypePTR, + Qclass: dns.ClassINET, + } + recs := s.Records(q) + if got, want := len(recs), 1; got != want { + t.Fatalf("s.Records(%v) returned %v records, want %v", q, got, want) + } + + want := dns.RR(&dns.PTR{ + Hdr: dns.RR_Header{ + Name: "_services._dns-sd._udp.local.", + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + Ttl: defaultTTL, + }, + Ptr: "_foobar._tcp.local.", + }) + if got := recs[0]; !reflect.DeepEqual(got, want) { + t.Errorf("s.Records()[0] = %v, want %v", got, want) + } +} diff --git a/util/mdns/server.go b/util/mdns/server.go new file mode 100644 index 00000000..909b39c5 --- /dev/null +++ b/util/mdns/server.go @@ -0,0 +1,476 @@ +package mdns + +import ( + "fmt" + "log" + "math/rand" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/miekg/dns" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +var ( + mdnsGroupIPv4 = net.ParseIP("224.0.0.251") + mdnsGroupIPv6 = net.ParseIP("ff02::fb") + + // mDNS wildcard addresses + mdnsWildcardAddrIPv4 = &net.UDPAddr{ + IP: net.ParseIP("224.0.0.0"), + Port: 5353, + } + mdnsWildcardAddrIPv6 = &net.UDPAddr{ + IP: net.ParseIP("ff02::"), + Port: 5353, + } + + // mDNS endpoint addresses + ipv4Addr = &net.UDPAddr{ + IP: mdnsGroupIPv4, + Port: 5353, + } + ipv6Addr = &net.UDPAddr{ + IP: mdnsGroupIPv6, + Port: 5353, + } +) + +// Config is used to configure the mDNS server +type Config struct { + // Zone must be provided to support responding to queries + Zone Zone + + // Iface if provided binds the multicast listener to the given + // interface. If not provided, the system default multicase interface + // is used. + Iface *net.Interface + + // Port If it is not 0, replace the port 5353 with this port number. + Port int +} + +// mDNS server is used to listen for mDNS queries and respond if we +// have a matching local record +type Server struct { + config *Config + + ipv4List *net.UDPConn + ipv6List *net.UDPConn + + shutdown bool + shutdownCh chan struct{} + shutdownLock sync.Mutex + wg sync.WaitGroup +} + +// NewServer is used to create a new mDNS server from a config +func NewServer(config *Config) (*Server, error) { + setCustomPort(config.Port) + + // Create the listeners + // Create wildcard connections (because :5353 can be already taken by other apps) + ipv4List, _ := net.ListenUDP("udp4", mdnsWildcardAddrIPv4) + ipv6List, _ := net.ListenUDP("udp6", mdnsWildcardAddrIPv6) + if ipv4List == nil && ipv6List == nil { + return nil, fmt.Errorf("[ERR] mdns: Failed to bind to any udp port!") + } + + if ipv4List == nil { + ipv4List = &net.UDPConn{} + } + if ipv6List == nil { + ipv6List = &net.UDPConn{} + } + + // Join multicast groups to receive announcements + p1 := ipv4.NewPacketConn(ipv4List) + p2 := ipv6.NewPacketConn(ipv6List) + p1.SetMulticastLoopback(true) + p2.SetMulticastLoopback(true) + + if config.Iface != nil { + if err := p1.JoinGroup(config.Iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { + return nil, err + } + if err := p2.JoinGroup(config.Iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { + return nil, err + } + } else { + ifaces, err := net.Interfaces() + if err != nil { + return nil, err + } + errCount1, errCount2 := 0, 0 + for _, iface := range ifaces { + if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { + errCount1++ + } + if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { + errCount2++ + } + } + if len(ifaces) == errCount1 && len(ifaces) == errCount2 { + return nil, fmt.Errorf("Failed to join multicast group on all interfaces!") + } + } + + s := &Server{ + config: config, + ipv4List: ipv4List, + ipv6List: ipv6List, + shutdownCh: make(chan struct{}), + } + + go s.recv(s.ipv4List) + go s.recv(s.ipv6List) + + s.wg.Add(1) + go s.probe() + + return s, nil +} + +// Shutdown is used to shutdown the listener +func (s *Server) Shutdown() error { + s.shutdownLock.Lock() + defer s.shutdownLock.Unlock() + + if s.shutdown { + return nil + } + + s.shutdown = true + close(s.shutdownCh) + s.unregister() + + if s.ipv4List != nil { + s.ipv4List.Close() + } + if s.ipv6List != nil { + s.ipv6List.Close() + } + + s.wg.Wait() + return nil +} + +// recv is a long running routine to receive packets from an interface +func (s *Server) recv(c *net.UDPConn) { + if c == nil { + return + } + buf := make([]byte, 65536) + for { + s.shutdownLock.Lock() + if s.shutdown { + s.shutdownLock.Unlock() + return + } + s.shutdownLock.Unlock() + n, from, err := c.ReadFrom(buf) + if err != nil { + continue + } + if err := s.parsePacket(buf[:n], from); err != nil { + log.Printf("[ERR] mdns: Failed to handle query: %v", err) + } + } +} + +// parsePacket is used to parse an incoming packet +func (s *Server) parsePacket(packet []byte, from net.Addr) error { + var msg dns.Msg + if err := msg.Unpack(packet); err != nil { + log.Printf("[ERR] mdns: Failed to unpack packet: %v", err) + return err + } + // TODO: This is a bit of a hack + // We decided to ignore some mDNS answers for the time being + // See: https://tools.ietf.org/html/rfc6762#section-7.2 + msg.Truncated = false + return s.handleQuery(&msg, from) +} + +// handleQuery is used to handle an incoming query +func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error { + if query.Opcode != dns.OpcodeQuery { + // "In both multicast query and multicast response messages, the OPCODE MUST + // be zero on transmission (only standard queries are currently supported + // over multicast). Multicast DNS messages received with an OPCODE other + // than zero MUST be silently ignored." Note: OpcodeQuery == 0 + return fmt.Errorf("mdns: received query with non-zero Opcode %v: %v", query.Opcode, *query) + } + if query.Rcode != 0 { + // "In both multicast query and multicast response messages, the Response + // Code MUST be zero on transmission. Multicast DNS messages received with + // non-zero Response Codes MUST be silently ignored." + return fmt.Errorf("mdns: received query with non-zero Rcode %v: %v", query.Rcode, *query) + } + + // TODO(reddaly): Handle "TC (Truncated) Bit": + // In query messages, if the TC bit is set, it means that additional + // Known-Answer records may be following shortly. A responder SHOULD + // record this fact, and wait for those additional Known-Answer records, + // before deciding whether to respond. If the TC bit is clear, it means + // that the querying host has no additional Known Answers. + if query.Truncated { + return fmt.Errorf("[ERR] mdns: support for DNS requests with high truncated bit not implemented: %v", *query) + } + + var unicastAnswer, multicastAnswer []dns.RR + + // Handle each question + for _, q := range query.Question { + mrecs, urecs := s.handleQuestion(q) + multicastAnswer = append(multicastAnswer, mrecs...) + unicastAnswer = append(unicastAnswer, urecs...) + } + + // See section 18 of RFC 6762 for rules about DNS headers. + resp := func(unicast bool) *dns.Msg { + // 18.1: ID (Query Identifier) + // 0 for multicast response, query.Id for unicast response + id := uint16(0) + if unicast { + id = query.Id + } + + var answer []dns.RR + if unicast { + answer = unicastAnswer + } else { + answer = multicastAnswer + } + if len(answer) == 0 { + return nil + } + + return &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: id, + + // 18.2: QR (Query/Response) Bit - must be set to 1 in response. + Response: true, + + // 18.3: OPCODE - must be zero in response (OpcodeQuery == 0) + Opcode: dns.OpcodeQuery, + + // 18.4: AA (Authoritative Answer) Bit - must be set to 1 + Authoritative: true, + + // The following fields must all be set to 0: + // 18.5: TC (TRUNCATED) Bit + // 18.6: RD (Recursion Desired) Bit + // 18.7: RA (Recursion Available) Bit + // 18.8: Z (Zero) Bit + // 18.9: AD (Authentic Data) Bit + // 18.10: CD (Checking Disabled) Bit + // 18.11: RCODE (Response Code) + }, + // 18.12 pertains to questions (handled by handleQuestion) + // 18.13 pertains to resource records (handled by handleQuestion) + + // 18.14: Name Compression - responses should be compressed (though see + // caveats in the RFC), so set the Compress bit (part of the dns library + // API, not part of the DNS packet) to true. + Compress: true, + + Answer: answer, + } + } + + if mresp := resp(false); mresp != nil { + if err := s.sendResponse(mresp, from); err != nil { + return fmt.Errorf("mdns: error sending multicast response: %v", err) + } + } + if uresp := resp(true); uresp != nil { + if err := s.sendResponse(uresp, from); err != nil { + return fmt.Errorf("mdns: error sending unicast response: %v", err) + } + } + return nil +} + +// handleQuestion is used to handle an incoming question +// +// The response to a question may be transmitted over multicast, unicast, or +// both. The return values are DNS records for each transmission type. +func (s *Server) handleQuestion(q dns.Question) (multicastRecs, unicastRecs []dns.RR) { + records := s.config.Zone.Records(q) + + if len(records) == 0 { + return nil, nil + } + + // Handle unicast and multicast responses. + // TODO(reddaly): The decision about sending over unicast vs. multicast is not + // yet fully compliant with RFC 6762. For example, the unicast bit should be + // ignored if the records in question are close to TTL expiration. For now, + // we just use the unicast bit to make the decision, as per the spec: + // RFC 6762, section 18.12. Repurposing of Top Bit of qclass in Question + // Section + // + // In the Question Section of a Multicast DNS query, the top bit of the + // qclass field is used to indicate that unicast responses are preferred + // for this particular question. (See Section 5.4.) + if q.Qclass&(1<<15) != 0 { + return nil, records + } + return records, nil +} + +func (s *Server) probe() { + defer s.wg.Done() + + sd, ok := s.config.Zone.(*MDNSService) + if !ok { + return + } + + name := fmt.Sprintf("%s.%s.%s.", sd.Instance, trimDot(sd.Service), trimDot(sd.Domain)) + + q := new(dns.Msg) + q.SetQuestion(name, dns.TypePTR) + q.RecursionDesired = false + + srv := &dns.SRV{ + Hdr: dns.RR_Header{ + Name: name, + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + Ttl: defaultTTL, + }, + Priority: 0, + Weight: 0, + Port: uint16(sd.Port), + Target: sd.HostName, + } + txt := &dns.TXT{ + Hdr: dns.RR_Header{ + Name: name, + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + Ttl: defaultTTL, + }, + Txt: sd.TXT, + } + q.Ns = []dns.RR{srv, txt} + + randomizer := rand.New(rand.NewSource(time.Now().UnixNano())) + + for i := 0; i < 3; i++ { + if err := s.SendMulticast(q); err != nil { + log.Println("[ERR] mdns: failed to send probe:", err.Error()) + } + time.Sleep(time.Duration(randomizer.Intn(250)) * time.Millisecond) + } + + resp := new(dns.Msg) + resp.MsgHdr.Response = true + + // set for query + q.SetQuestion(name, dns.TypeANY) + + resp.Answer = append(resp.Answer, s.config.Zone.Records(q.Question[0])...) + + // reset + q.SetQuestion(name, dns.TypePTR) + + // From RFC6762 + // The Multicast DNS responder MUST send at least two unsolicited + // responses, one second apart. To provide increased robustness against + // packet loss, a responder MAY send up to eight unsolicited responses, + // provided that the interval between unsolicited responses increases by + // at least a factor of two with every response sent. + timeout := 1 * time.Second + timer := time.NewTimer(timeout) + for i := 0; i < 3; i++ { + if err := s.SendMulticast(resp); err != nil { + log.Println("[ERR] mdns: failed to send announcement:", err.Error()) + } + select { + case <-timer.C: + timeout *= 2 + timer.Reset(timeout) + case <-s.shutdownCh: + timer.Stop() + return + } + } +} + +// multicastResponse us used to send a multicast response packet +func (s *Server) SendMulticast(msg *dns.Msg) error { + buf, err := msg.Pack() + if err != nil { + return err + } + if s.ipv4List != nil { + s.ipv4List.WriteToUDP(buf, ipv4Addr) + } + if s.ipv6List != nil { + s.ipv6List.WriteToUDP(buf, ipv6Addr) + } + return nil +} + +// sendResponse is used to send a response packet +func (s *Server) sendResponse(resp *dns.Msg, from net.Addr) error { + // TODO(reddaly): Respect the unicast argument, and allow sending responses + // over multicast. + buf, err := resp.Pack() + if err != nil { + return err + } + + // Determine the socket to send from + addr := from.(*net.UDPAddr) + if addr.IP.To4() != nil { + _, err = s.ipv4List.WriteToUDP(buf, addr) + return err + } else { + _, err = s.ipv6List.WriteToUDP(buf, addr) + return err + } +} + +func (s *Server) unregister() error { + sd, ok := s.config.Zone.(*MDNSService) + if !ok { + return nil + } + + atomic.StoreUint32(&sd.TTL, 0) + name := fmt.Sprintf("%s.%s.%s.", sd.Instance, trimDot(sd.Service), trimDot(sd.Domain)) + + q := new(dns.Msg) + q.SetQuestion(name, dns.TypeANY) + + resp := new(dns.Msg) + resp.MsgHdr.Response = true + resp.Answer = append(resp.Answer, s.config.Zone.Records(q.Question[0])...) + + return s.SendMulticast(resp) +} + +func setCustomPort(port int) { + if port != 0 { + if mdnsWildcardAddrIPv4.Port != port { + mdnsWildcardAddrIPv4.Port = port + } + if mdnsWildcardAddrIPv6.Port != port { + mdnsWildcardAddrIPv6.Port = port + } + if ipv4Addr.Port != port { + ipv4Addr.Port = port + } + if ipv6Addr.Port != port { + ipv6Addr.Port = port + } + } +} diff --git a/util/mdns/server_test.go b/util/mdns/server_test.go new file mode 100644 index 00000000..6fb00fa2 --- /dev/null +++ b/util/mdns/server_test.go @@ -0,0 +1,61 @@ +package mdns + +import ( + "testing" + "time" +) + +func TestServer_StartStop(t *testing.T) { + s := makeService(t) + serv, err := NewServer(&Config{Zone: s}) + if err != nil { + t.Fatalf("err: %v", err) + } + defer serv.Shutdown() +} + +func TestServer_Lookup(t *testing.T) { + serv, err := NewServer(&Config{Zone: makeServiceWithServiceName(t, "_foobar._tcp")}) + if err != nil { + t.Fatalf("err: %v", err) + } + defer serv.Shutdown() + + entries := make(chan *ServiceEntry, 1) + found := false + doneCh := make(chan struct{}) + go func() { + select { + case e := <-entries: + if e.Name != "hostname._foobar._tcp.local." { + t.Fatalf("bad: %v", e) + } + if e.Port != 80 { + t.Fatalf("bad: %v", e) + } + if e.Info != "Local web server" { + t.Fatalf("bad: %v", e) + } + found = true + + case <-time.After(80 * time.Millisecond): + t.Fatalf("timeout") + } + close(doneCh) + }() + + params := &QueryParam{ + Service: "_foobar._tcp", + Domain: "local", + Timeout: 50 * time.Millisecond, + Entries: entries, + } + err = Query(params) + if err != nil { + t.Fatalf("err: %v", err) + } + <-doneCh + if !found { + t.Fatalf("record not found") + } +} diff --git a/util/mdns/zone.go b/util/mdns/zone.go new file mode 100644 index 00000000..abbab4bd --- /dev/null +++ b/util/mdns/zone.go @@ -0,0 +1,309 @@ +package mdns + +import ( + "fmt" + "net" + "os" + "strings" + "sync/atomic" + + "github.com/miekg/dns" +) + +const ( + // defaultTTL is the default TTL value in returned DNS records in seconds. + defaultTTL = 120 +) + +// Zone is the interface used to integrate with the server and +// to serve records dynamically +type Zone interface { + // Records returns DNS records in response to a DNS question. + Records(q dns.Question) []dns.RR +} + +// MDNSService is used to export a named service by implementing a Zone +type MDNSService struct { + Instance string // Instance name (e.g. "hostService name") + Service string // Service name (e.g. "_http._tcp.") + Domain string // If blank, assumes "local" + HostName string // Host machine DNS name (e.g. "mymachine.net.") + Port int // Service Port + IPs []net.IP // IP addresses for the service's host + TXT []string // Service TXT records + TTL uint32 + serviceAddr string // Fully qualified service address + instanceAddr string // Fully qualified instance address + enumAddr string // _services._dns-sd._udp. +} + +// validateFQDN returns an error if the passed string is not a fully qualified +// hdomain name (more specifically, a hostname). +func validateFQDN(s string) error { + if len(s) == 0 { + return fmt.Errorf("FQDN must not be blank") + } + if s[len(s)-1] != '.' { + return fmt.Errorf("FQDN must end in period: %s", s) + } + // TODO(reddaly): Perform full validation. + + return nil +} + +// NewMDNSService returns a new instance of MDNSService. +// +// If domain, hostName, or ips is set to the zero value, then a default value +// will be inferred from the operating system. +// +// TODO(reddaly): This interface may need to change to account for "unique +// record" conflict rules of the mDNS protocol. Upon startup, the server should +// check to ensure that the instance name does not conflict with other instance +// names, and, if required, select a new name. There may also be conflicting +// hostName A/AAAA records. +func NewMDNSService(instance, service, domain, hostName string, port int, ips []net.IP, txt []string) (*MDNSService, error) { + // Sanity check inputs + if instance == "" { + return nil, fmt.Errorf("missing service instance name") + } + if service == "" { + return nil, fmt.Errorf("missing service name") + } + if port == 0 { + return nil, fmt.Errorf("missing service port") + } + + // Set default domain + if domain == "" { + domain = "local." + } + if err := validateFQDN(domain); err != nil { + return nil, fmt.Errorf("domain %q is not a fully-qualified domain name: %v", domain, err) + } + + // Get host information if no host is specified. + if hostName == "" { + var err error + hostName, err = os.Hostname() + if err != nil { + return nil, fmt.Errorf("could not determine host: %v", err) + } + hostName = fmt.Sprintf("%s.", hostName) + } + if err := validateFQDN(hostName); err != nil { + return nil, fmt.Errorf("hostName %q is not a fully-qualified domain name: %v", hostName, err) + } + + if len(ips) == 0 { + var err error + ips, err = net.LookupIP(trimDot(hostName)) + if err != nil { + // Try appending the host domain suffix and lookup again + // (required for Linux-based hosts) + tmpHostName := fmt.Sprintf("%s%s", hostName, domain) + + ips, err = net.LookupIP(trimDot(tmpHostName)) + + if err != nil { + return nil, fmt.Errorf("could not determine host IP addresses for %s", hostName) + } + } + } + for _, ip := range ips { + if ip.To4() == nil && ip.To16() == nil { + return nil, fmt.Errorf("invalid IP address in IPs list: %v", ip) + } + } + + return &MDNSService{ + Instance: instance, + Service: service, + Domain: domain, + HostName: hostName, + Port: port, + IPs: ips, + TXT: txt, + TTL: defaultTTL, + serviceAddr: fmt.Sprintf("%s.%s.", trimDot(service), trimDot(domain)), + instanceAddr: fmt.Sprintf("%s.%s.%s.", instance, trimDot(service), trimDot(domain)), + enumAddr: fmt.Sprintf("_services._dns-sd._udp.%s.", trimDot(domain)), + }, nil +} + +// trimDot is used to trim the dots from the start or end of a string +func trimDot(s string) string { + return strings.Trim(s, ".") +} + +// Records returns DNS records in response to a DNS question. +func (m *MDNSService) Records(q dns.Question) []dns.RR { + switch q.Name { + case m.enumAddr: + return m.serviceEnum(q) + case m.serviceAddr: + return m.serviceRecords(q) + case m.instanceAddr: + return m.instanceRecords(q) + case m.HostName: + if q.Qtype == dns.TypeA || q.Qtype == dns.TypeAAAA { + return m.instanceRecords(q) + } + fallthrough + default: + return nil + } +} + +func (m *MDNSService) serviceEnum(q dns.Question) []dns.RR { + switch q.Qtype { + case dns.TypeANY: + fallthrough + case dns.TypePTR: + rr := &dns.PTR{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + Ttl: atomic.LoadUint32(&m.TTL), + }, + Ptr: m.serviceAddr, + } + return []dns.RR{rr} + default: + return nil + } +} + +// serviceRecords is called when the query matches the service name +func (m *MDNSService) serviceRecords(q dns.Question) []dns.RR { + switch q.Qtype { + case dns.TypeANY: + fallthrough + case dns.TypePTR: + // Build a PTR response for the service + rr := &dns.PTR{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + Ttl: atomic.LoadUint32(&m.TTL), + }, + Ptr: m.instanceAddr, + } + servRec := []dns.RR{rr} + + // Get the instance records + instRecs := m.instanceRecords(dns.Question{ + Name: m.instanceAddr, + Qtype: dns.TypeANY, + }) + + // Return the service record with the instance records + return append(servRec, instRecs...) + default: + return nil + } +} + +// serviceRecords is called when the query matches the instance name +func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR { + switch q.Qtype { + case dns.TypeANY: + // Get the SRV, which includes A and AAAA + recs := m.instanceRecords(dns.Question{ + Name: m.instanceAddr, + Qtype: dns.TypeSRV, + }) + + // Add the TXT record + recs = append(recs, m.instanceRecords(dns.Question{ + Name: m.instanceAddr, + Qtype: dns.TypeTXT, + })...) + return recs + + case dns.TypeA: + var rr []dns.RR + for _, ip := range m.IPs { + if ip4 := ip.To4(); ip4 != nil { + rr = append(rr, &dns.A{ + Hdr: dns.RR_Header{ + Name: m.HostName, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: atomic.LoadUint32(&m.TTL), + }, + A: ip4, + }) + } + } + return rr + + case dns.TypeAAAA: + var rr []dns.RR + for _, ip := range m.IPs { + if ip.To4() != nil { + // TODO(reddaly): IPv4 addresses could be encoded in IPv6 format and + // putinto AAAA records, but the current logic puts ipv4-encodable + // addresses into the A records exclusively. Perhaps this should be + // configurable? + continue + } + + if ip16 := ip.To16(); ip16 != nil { + rr = append(rr, &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: m.HostName, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: atomic.LoadUint32(&m.TTL), + }, + AAAA: ip16, + }) + } + } + return rr + + case dns.TypeSRV: + // Create the SRV Record + srv := &dns.SRV{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + Ttl: atomic.LoadUint32(&m.TTL), + }, + Priority: 10, + Weight: 1, + Port: uint16(m.Port), + Target: m.HostName, + } + recs := []dns.RR{srv} + + // Add the A record + recs = append(recs, m.instanceRecords(dns.Question{ + Name: m.instanceAddr, + Qtype: dns.TypeA, + })...) + + // Add the AAAA record + recs = append(recs, m.instanceRecords(dns.Question{ + Name: m.instanceAddr, + Qtype: dns.TypeAAAA, + })...) + return recs + + case dns.TypeTXT: + txt := &dns.TXT{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + Ttl: atomic.LoadUint32(&m.TTL), + }, + Txt: m.TXT, + } + return []dns.RR{txt} + } + return nil +} diff --git a/util/mdns/zone_test.go b/util/mdns/zone_test.go new file mode 100644 index 00000000..082d72dd --- /dev/null +++ b/util/mdns/zone_test.go @@ -0,0 +1,275 @@ +package mdns + +import ( + "bytes" + "net" + "reflect" + "testing" + + "github.com/miekg/dns" +) + +func makeService(t *testing.T) *MDNSService { + return makeServiceWithServiceName(t, "_http._tcp") +} + +func makeServiceWithServiceName(t *testing.T, service string) *MDNSService { + m, err := NewMDNSService( + "hostname", + service, + "local.", + "testhost.", + 80, // port + []net.IP{net.IP([]byte{192, 168, 0, 42}), net.ParseIP("2620:0:1000:1900:b0c2:d0b2:c411:18bc")}, + []string{"Local web server"}) // TXT + + if err != nil { + t.Fatalf("err: %v", err) + } + + return m +} + +func TestNewMDNSService_BadParams(t *testing.T) { + for _, test := range []struct { + testName string + hostName string + domain string + }{ + { + "NewMDNSService should fail when passed hostName that is not a legal fully-qualified domain name", + "hostname", // not legal FQDN - should be "hostname." or "hostname.local.", etc. + "local.", // legal + }, + { + "NewMDNSService should fail when passed domain that is not a legal fully-qualified domain name", + "hostname.", // legal + "local", // should be "local." + }, + } { + _, err := NewMDNSService( + "instance name", + "_http._tcp", + test.domain, + test.hostName, + 80, // port + []net.IP{net.IP([]byte{192, 168, 0, 42})}, + []string{"Local web server"}) // TXT + if err == nil { + t.Fatalf("%s: error expected, but got none", test.testName) + } + } +} + +func TestMDNSService_BadAddr(t *testing.T) { + s := makeService(t) + q := dns.Question{ + Name: "random", + Qtype: dns.TypeANY, + } + recs := s.Records(q) + if len(recs) != 0 { + t.Fatalf("bad: %v", recs) + } +} + +func TestMDNSService_ServiceAddr(t *testing.T) { + s := makeService(t) + q := dns.Question{ + Name: "_http._tcp.local.", + Qtype: dns.TypeANY, + } + recs := s.Records(q) + if got, want := len(recs), 5; got != want { + t.Fatalf("got %d records, want %d: %v", got, want, recs) + } + + if ptr, ok := recs[0].(*dns.PTR); !ok { + t.Errorf("recs[0] should be PTR record, got: %v, all records: %v", recs[0], recs) + } else if got, want := ptr.Ptr, "hostname._http._tcp.local."; got != want { + t.Fatalf("bad PTR record %v: got %v, want %v", ptr, got, want) + } + + if _, ok := recs[1].(*dns.SRV); !ok { + t.Errorf("recs[1] should be SRV record, got: %v, all reccords: %v", recs[1], recs) + } + if _, ok := recs[2].(*dns.A); !ok { + t.Errorf("recs[2] should be A record, got: %v, all records: %v", recs[2], recs) + } + if _, ok := recs[3].(*dns.AAAA); !ok { + t.Errorf("recs[3] should be AAAA record, got: %v, all records: %v", recs[3], recs) + } + if _, ok := recs[4].(*dns.TXT); !ok { + t.Errorf("recs[4] should be TXT record, got: %v, all records: %v", recs[4], recs) + } + + q.Qtype = dns.TypePTR + if recs2 := s.Records(q); !reflect.DeepEqual(recs, recs2) { + t.Fatalf("PTR question should return same result as ANY question: ANY => %v, PTR => %v", recs, recs2) + } +} + +func TestMDNSService_InstanceAddr_ANY(t *testing.T) { + s := makeService(t) + q := dns.Question{ + Name: "hostname._http._tcp.local.", + Qtype: dns.TypeANY, + } + recs := s.Records(q) + if len(recs) != 4 { + t.Fatalf("bad: %v", recs) + } + if _, ok := recs[0].(*dns.SRV); !ok { + t.Fatalf("bad: %v", recs[0]) + } + if _, ok := recs[1].(*dns.A); !ok { + t.Fatalf("bad: %v", recs[1]) + } + if _, ok := recs[2].(*dns.AAAA); !ok { + t.Fatalf("bad: %v", recs[2]) + } + if _, ok := recs[3].(*dns.TXT); !ok { + t.Fatalf("bad: %v", recs[3]) + } +} + +func TestMDNSService_InstanceAddr_SRV(t *testing.T) { + s := makeService(t) + q := dns.Question{ + Name: "hostname._http._tcp.local.", + Qtype: dns.TypeSRV, + } + recs := s.Records(q) + if len(recs) != 3 { + t.Fatalf("bad: %v", recs) + } + srv, ok := recs[0].(*dns.SRV) + if !ok { + t.Fatalf("bad: %v", recs[0]) + } + if _, ok := recs[1].(*dns.A); !ok { + t.Fatalf("bad: %v", recs[1]) + } + if _, ok := recs[2].(*dns.AAAA); !ok { + t.Fatalf("bad: %v", recs[2]) + } + + if srv.Port != uint16(s.Port) { + t.Fatalf("bad: %v", recs[0]) + } +} + +func TestMDNSService_InstanceAddr_A(t *testing.T) { + s := makeService(t) + q := dns.Question{ + Name: "hostname._http._tcp.local.", + Qtype: dns.TypeA, + } + recs := s.Records(q) + if len(recs) != 1 { + t.Fatalf("bad: %v", recs) + } + a, ok := recs[0].(*dns.A) + if !ok { + t.Fatalf("bad: %v", recs[0]) + } + if !bytes.Equal(a.A, []byte{192, 168, 0, 42}) { + t.Fatalf("bad: %v", recs[0]) + } +} + +func TestMDNSService_InstanceAddr_AAAA(t *testing.T) { + s := makeService(t) + q := dns.Question{ + Name: "hostname._http._tcp.local.", + Qtype: dns.TypeAAAA, + } + recs := s.Records(q) + if len(recs) != 1 { + t.Fatalf("bad: %v", recs) + } + a4, ok := recs[0].(*dns.AAAA) + if !ok { + t.Fatalf("bad: %v", recs[0]) + } + ip6 := net.ParseIP("2620:0:1000:1900:b0c2:d0b2:c411:18bc") + if got := len(ip6); got != net.IPv6len { + t.Fatalf("test IP failed to parse (len = %d, want %d)", got, net.IPv6len) + } + if !bytes.Equal(a4.AAAA, ip6) { + t.Fatalf("bad: %v", recs[0]) + } +} + +func TestMDNSService_InstanceAddr_TXT(t *testing.T) { + s := makeService(t) + q := dns.Question{ + Name: "hostname._http._tcp.local.", + Qtype: dns.TypeTXT, + } + recs := s.Records(q) + if len(recs) != 1 { + t.Fatalf("bad: %v", recs) + } + txt, ok := recs[0].(*dns.TXT) + if !ok { + t.Fatalf("bad: %v", recs[0]) + } + if got, want := txt.Txt, s.TXT; !reflect.DeepEqual(got, want) { + t.Fatalf("TXT record mismatch for %v: got %v, want %v", recs[0], got, want) + } +} + +func TestMDNSService_HostNameQuery(t *testing.T) { + s := makeService(t) + for _, test := range []struct { + q dns.Question + want []dns.RR + }{ + { + dns.Question{Name: "testhost.", Qtype: dns.TypeA}, + []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Name: "testhost.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 120, + }, + A: net.IP([]byte{192, 168, 0, 42}), + }}, + }, + { + dns.Question{Name: "testhost.", Qtype: dns.TypeAAAA}, + []dns.RR{&dns.AAAA{ + Hdr: dns.RR_Header{ + Name: "testhost.", + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 120, + }, + AAAA: net.ParseIP("2620:0:1000:1900:b0c2:d0b2:c411:18bc"), + }}, + }, + } { + if got := s.Records(test.q); !reflect.DeepEqual(got, test.want) { + t.Errorf("hostname query failed: s.Records(%v) = %v, want %v", test.q, got, test.want) + } + } +} + +func TestMDNSService_serviceEnum_PTR(t *testing.T) { + s := makeService(t) + q := dns.Question{ + Name: "_services._dns-sd._udp.local.", + Qtype: dns.TypePTR, + } + recs := s.Records(q) + if len(recs) != 1 { + t.Fatalf("bad: %v", recs) + } + if ptr, ok := recs[0].(*dns.PTR); !ok { + t.Errorf("recs[0] should be PTR record, got: %v, all records: %v", recs[0], recs) + } else if got, want := ptr.Ptr, "_http._tcp.local."; got != want { + t.Fatalf("bad PTR record %v: got %v, want %v", ptr, got, want) + } +}