diff --git a/network/router/default_router.go b/network/router/default_router.go index 91f9e2a2..5c7b73ae 100644 --- a/network/router/default_router.go +++ b/network/router/default_router.go @@ -134,12 +134,10 @@ func (r *router) manageServiceRoutes(w registry.Watcher, metric int) error { for { res, err := w.Next() - if err == registry.ErrWatcherStopped { - break - } - if err != nil { - watchErr = err + if err != registry.ErrWatcherStopped { + watchErr = err + } break } @@ -181,14 +179,13 @@ func (r *router) watchTable(w Watcher) error { var watchErr error +exit: for { event, err := w.Next() - if err == ErrWatcherStopped { - break - } - if err != nil { - watchErr = err + if err != ErrWatcherStopped { + watchErr = err + } break } @@ -200,23 +197,21 @@ func (r *router) watchTable(w Watcher) error { select { case <-r.exit: - return nil + break exit case r.advertChan <- u: } } + // close the advertisement channel + close(r.advertChan) + return watchErr } -// manageStatus manages router status -func (r *router) manageStatus(errChan <-chan error) { +// watchError watches router errors +func (r *router) watchError(errChan <-chan error) { defer r.wg.Done() - r.Lock() - r.status.Code = Running - r.status.Error = nil - r.Unlock() - var code StatusCode var err error @@ -229,14 +224,13 @@ func (r *router) manageStatus(errChan <-chan error) { r.Lock() defer r.Unlock() - r.status.Code = code - r.status.Error = err - - // close the advertise channel - close(r.advertChan) + status := Status{ + Code: code, + Error: err, + } + r.status = status // stop the router if some error happened - // this will notify all watcher goroutines to stop if err != nil && code != Stopped { close(r.exit) } @@ -303,7 +297,14 @@ func (r *router) Advertise() (<-chan *Update, error) { }() r.wg.Add(1) - go r.manageStatus(errChan) + go r.watchError(errChan) + + // mark router as running and set its Error to nil + status := Status{ + Code: Running, + Error: nil, + } + r.status = status } return r.advertChan, nil @@ -338,15 +339,19 @@ func (r *router) Status() Status { // Stop stops the router func (r *router) Stop() error { r.RLock() - defer r.RUnlock() - // only close the channel if the router is running if r.status.Code == Running { // notify all goroutines to finish close(r.exit) - // wait for all goroutines to finish - r.wg.Wait() } + r.RUnlock() + + // drain the advertise channel + for range r.advertChan { + } + + // wait for all goroutines to finish + r.wg.Wait() return nil } diff --git a/network/router/table_watcher.go b/network/router/table_watcher.go index 4e9d9a9e..91411247 100644 --- a/network/router/table_watcher.go +++ b/network/router/table_watcher.go @@ -9,7 +9,7 @@ import ( var ( // ErrWatcherStopped is returned when routing table watcher has been stopped - ErrWatcherStopped = errors.New("routing table watcher stopped") + ErrWatcherStopped = errors.New("watcher stopped") ) // EventType defines routing table event diff --git a/registry/mdns_watcher.go b/registry/mdns_watcher.go index 7ccb6e80..bbcf90ea 100644 --- a/registry/mdns_watcher.go +++ b/registry/mdns_watcher.go @@ -1,7 +1,6 @@ package registry import ( - "errors" "strings" "github.com/micro/mdns" @@ -63,7 +62,7 @@ func (m *mdnsWatcher) Next() (*Result, error) { Service: service, }, nil case <-m.exit: - return nil, errors.New("watcher stopped") + return nil, ErrWatcherStopped } } }