diff --git a/network/router/default.go b/network/router/default.go index cf98da3c..0b8c41de 100644 --- a/network/router/default.go +++ b/network/router/default.go @@ -445,7 +445,6 @@ func (r *router) watchErrors() { } // Run runs the router. -// It returns error if the router is already running. func (r *router) run() { r.Lock() defer r.Unlock() diff --git a/network/router/service/service.go b/network/router/service/service.go index f651131a..5a6daeda 100644 --- a/network/router/service/service.go +++ b/network/router/service/service.go @@ -2,7 +2,6 @@ package service import ( "context" - "errors" "fmt" "io" "sync" @@ -14,11 +13,6 @@ import ( pb "github.com/micro/go-micro/network/router/proto" ) -var ( - // ErrNotImplemented means the functionality has not been implemented - ErrNotImplemented = errors.New("not implemented") -) - type svc struct { opts router.Options router pb.RouterService @@ -71,31 +65,10 @@ func (s *svc) Options() router.Options { return s.opts } -// watchErrors watches router errors and takes appropriate actions -func (s *svc) watchErrors() { - var err error - - select { - case <-s.exit: - case err = <-s.errChan: - } - - s.Lock() - defer s.Unlock() - if s.status.Code != router.Stopped { - // notify all goroutines to finish - close(s.exit) - // TODO" might need to drain some channels here - } - - if err != nil { - s.status = router.Status{Code: router.Error, Error: err} - } -} - // watchRouter watches router and send events to all registered watchers func (s *svc) watchRouter(stream pb.Router_WatchService) error { defer stream.Close() + var watchErr error for { @@ -122,6 +95,7 @@ func (s *svc) watchRouter(stream pb.Router_WatchService) error { Route: route, } + // TODO: might make this non-blocking s.RLock() for _, w := range s.watchers { select { @@ -135,8 +109,31 @@ func (s *svc) watchRouter(stream pb.Router_WatchService) error { return watchErr } +// watchErrors watches router errors and takes appropriate actions +func (s *svc) watchErrors() { + var err error + + select { + case <-s.exit: + case err = <-s.errChan: + } + + s.Lock() + defer s.Unlock() + if s.status.Code != router.Stopped { + // notify all goroutines to finish + close(s.exit) + // drain the advertise channel + for range s.advertChan { + } + } + + if err != nil { + s.status = router.Status{Code: router.Error, Error: err} + } +} + // Run runs the router. -// It returns error if the router is already running. func (s *svc) run() { s.Lock() defer s.Unlock() @@ -145,7 +142,7 @@ func (s *svc) run() { case router.Stopped, router.Error: stream, err := s.router.Watch(context.Background(), &pb.WatchRequest{}) if err != nil { - s.status = router.Status{Code: router.Error, Error: fmt.Errorf("failed getting router stream: %s", err)} + s.status = router.Status{Code: router.Error, Error: fmt.Errorf("failed getting event stream: %s", err)} return } @@ -425,6 +422,14 @@ func (s *svc) Watch(opts ...router.WatchOption) (router.Watcher, error) { s.watchers[uuid.New().String()] = w s.Unlock() + // when the router stops, stop the watcher and exit + s.wg.Add(1) + go func() { + defer s.wg.Done() + <-s.exit + w.Stop() + }() + return w, nil } @@ -446,8 +451,9 @@ func (s *svc) Stop() error { if s.status.Code == router.Running || s.status.Code == router.Advertising { // notify all goroutines to finish close(s.exit) - // TODO: might need to drain some channels here - + // drain the advertise channel + for range s.advertChan { + } // mark the router as Stopped and set its Error to nil s.status = router.Status{Code: router.Stopped, Error: nil} }