1
0
mirror of https://github.com/go-micro/go-micro.git synced 2024-12-24 10:07:04 +02:00

update mdns to remove race condition

This commit is contained in:
Asim Aslam 2019-02-01 13:41:11 +00:00
parent 652b1067f5
commit 88e12347d0

View File

@ -2,6 +2,7 @@
package registry
import (
"context"
"net"
"strings"
"sync"
@ -194,20 +195,20 @@ func (m *mdnsRegistry) Deregister(service *Service) error {
}
func (m *mdnsRegistry) GetService(service string) ([]*Service, error) {
p := mdns.DefaultParams(service)
p.Timeout = m.opts.Timeout
entryCh := make(chan *mdns.ServiceEntry, 10)
p.Entries = entryCh
exit := make(chan bool)
defer close(exit)
serviceMap := make(map[string]*Service)
entries := make(chan *mdns.ServiceEntry, 10)
done := make(chan bool)
p := mdns.DefaultParams(service)
// set context with timeout
p.Context, _ = context.WithTimeout(context.Background(), m.opts.Timeout)
// set entries channel
p.Entries = entries
go func() {
for {
select {
case e := <-entryCh:
case e := <-entries:
// list record so skip
if p.Service == "_services" {
continue
@ -243,16 +244,21 @@ func (m *mdnsRegistry) GetService(service string) ([]*Service, error) {
})
serviceMap[txt.Version] = s
case <-exit:
case <-p.Context.Done():
close(done)
return
}
}
}()
// execute the query
if err := mdns.Query(p); err != nil {
return nil, err
}
// wait for completion
<-done
// create list and return
var services []*Service
@ -264,21 +270,22 @@ func (m *mdnsRegistry) GetService(service string) ([]*Service, error) {
}
func (m *mdnsRegistry) ListServices() ([]*Service, error) {
p := mdns.DefaultParams("_services")
p.Timeout = m.opts.Timeout
entryCh := make(chan *mdns.ServiceEntry, 10)
p.Entries = entryCh
exit := make(chan bool)
defer close(exit)
serviceMap := make(map[string]bool)
entries := make(chan *mdns.ServiceEntry, 10)
done := make(chan bool)
p := mdns.DefaultParams("_services")
// set context with timeout
p.Context, _ = context.WithTimeout(context.Background(), m.opts.Timeout)
// set entries channel
p.Entries = entries
var services []*Service
go func() {
for {
select {
case e := <-entryCh:
case e := <-entries:
if e.TTL == 0 {
continue
}
@ -288,16 +295,21 @@ func (m *mdnsRegistry) ListServices() ([]*Service, error) {
serviceMap[name] = true
services = append(services, &Service{Name: name})
}
case <-exit:
case <-p.Context.Done():
close(done)
return
}
}
}()
// execute query
if err := mdns.Query(p); err != nil {
return nil, err
}
// wait till done
<-done
return services, nil
}