diff --git a/server/context.go b/server/context.go index 88d19257..c7b08b93 100644 --- a/server/context.go +++ b/server/context.go @@ -6,6 +6,14 @@ import ( type serverKey struct{} +func wait(ctx context.Context) bool { + if ctx == nil { + return false + } + wait, _ := ctx.Value("wait").(bool) + return wait +} + func FromContext(ctx context.Context) (Server, bool) { c, ok := ctx.Value(serverKey{}).(Server) return c, ok diff --git a/server/rpc_server.go b/server/rpc_server.go index 6f97ab9c..48fd97c1 100644 --- a/server/rpc_server.go +++ b/server/rpc_server.go @@ -30,6 +30,8 @@ type rpcServer struct { subscribers map[*subscriber][]broker.Subscriber // used for first registration registered bool + // graceful exit + wg sync.WaitGroup } func newRpcServer(opts ...Option) Server { @@ -44,6 +46,7 @@ func newRpcServer(opts ...Option) Server { handlers: make(map[string]Handler), subscribers: make(map[*subscriber][]broker.Subscriber), exit: make(chan chan error), + wg: sync.WaitGroup{}, } } @@ -100,11 +103,18 @@ func (s *rpcServer) accept(sock transport.Socket) { } } + // add to wait group + s.wg.Add(1) + // TODO: needs better error handling if err := s.rpc.serveRequest(ctx, codec, ct); err != nil { log.Logf("Unexpected error serving request, closing socket: %v", err) + s.wg.Done() return } + + // finish request + s.wg.Done() } } @@ -371,8 +381,18 @@ func (s *rpcServer) Start() error { go ts.Accept(s.accept) go func() { + // wait for exit ch := <-s.exit + + // wait for requests to finish + if wait(s.opts.Context) { + s.wg.Wait() + } + + // close transport listener ch <- ts.Close() + + // disconnect the broker config.Broker.Disconnect() }()