1
0
mirror of https://github.com/go-micro/go-micro.git synced 2024-12-12 08:23:58 +02:00

Merge pull request #488 from magodo/wait_accept_custom_wg

`Wait()` option now accept *sync.WaitGroup
This commit is contained in:
Asim Aslam 2019-05-27 15:02:09 +01:00 committed by GitHub
commit 58a70562d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 32 additions and 15 deletions

View File

@ -54,7 +54,7 @@ func newFunction(opts ...Option) Function {
service.Server().Init( service.Server().Init(
// ensure the service waits for requests to finish // ensure the service waits for requests to finish
server.Wait(true), server.Wait(nil),
// wrap handlers and subscribers to finish execution // wrap handlers and subscribers to finish execution
server.WrapHandler(fnHandlerWrapper(fn)), server.WrapHandler(fnHandlerWrapper(fn)),
server.WrapSubscriber(fnSubWrapper(fn)), server.WrapSubscriber(fnSubWrapper(fn)),

View File

@ -2,19 +2,20 @@ package server
import ( import (
"context" "context"
"sync"
) )
type serverKey struct{} type serverKey struct{}
func wait(ctx context.Context) bool { func wait(ctx context.Context) *sync.WaitGroup {
if ctx == nil { if ctx == nil {
return false return nil
} }
wait, ok := ctx.Value("wait").(bool) wg, ok := ctx.Value("wait").(*sync.WaitGroup)
if !ok { if !ok {
return false return nil
} }
return wait return wg
} }
func FromContext(ctx context.Context) (Server, bool) { func FromContext(ctx context.Context) (Server, bool) {

View File

@ -2,6 +2,7 @@ package server
import ( import (
"context" "context"
"sync"
"time" "time"
"github.com/micro/go-micro/broker" "github.com/micro/go-micro/broker"
@ -198,12 +199,18 @@ func WithRouter(r Router) Option {
} }
// Wait tells the server to wait for requests to finish before exiting // Wait tells the server to wait for requests to finish before exiting
func Wait(b bool) Option { // If `wg` is nil, server only wait for completion of rpc handler.
// For user need finer grained control, pass a concrete `wg` here, server will
// wait against it on stop.
func Wait(wg *sync.WaitGroup) Option {
return func(o *Options) { return func(o *Options) {
if o.Context == nil { if o.Context == nil {
o.Context = context.Background() o.Context = context.Background()
} }
o.Context = context.WithValue(o.Context, "wait", b) if wg == nil {
wg = new(sync.WaitGroup)
}
o.Context = context.WithValue(o.Context, "wait", wg)
} }
} }

View File

@ -31,7 +31,7 @@ type rpcServer struct {
// used for first registration // used for first registration
registered bool registered bool
// graceful exit // graceful exit
wg sync.WaitGroup wg *sync.WaitGroup
} }
func newRpcServer(opts ...Option) Server { func newRpcServer(opts ...Option) Server {
@ -42,6 +42,7 @@ func newRpcServer(opts ...Option) Server {
handlers: make(map[string]Handler), handlers: make(map[string]Handler),
subscribers: make(map[*subscriber][]broker.Subscriber), subscribers: make(map[*subscriber][]broker.Subscriber),
exit: make(chan chan error), exit: make(chan chan error),
wg: wait(options.Context),
} }
} }
@ -63,8 +64,10 @@ func (s *rpcServer) ServeConn(sock transport.Socket) {
return return
} }
// add to wait group // add to wait group if "wait" is opt-in
if s.wg != nil {
s.wg.Add(1) s.wg.Add(1)
}
// we use this Timeout header to set a server deadline // we use this Timeout header to set a server deadline
to := msg.Header["Timeout"] to := msg.Header["Timeout"]
@ -111,7 +114,9 @@ func (s *rpcServer) ServeConn(sock transport.Socket) {
}, },
Body: []byte(err.Error()), Body: []byte(err.Error()),
}) })
if s.wg != nil {
s.wg.Done() s.wg.Done()
}
return return
} }
} }
@ -167,14 +172,18 @@ func (s *rpcServer) ServeConn(sock transport.Socket) {
if err != nil { if err != nil {
log.Logf("rpc: unable to write error response: %v", err) log.Logf("rpc: unable to write error response: %v", err)
} }
if s.wg != nil {
s.wg.Done() s.wg.Done()
}
return return
} }
// done // done
if s.wg != nil {
s.wg.Done() s.wg.Done()
} }
} }
}
func (s *rpcServer) newCodec(contentType string) (codec.NewCodec, error) { func (s *rpcServer) newCodec(contentType string) (codec.NewCodec, error) {
if cf, ok := s.opts.Codecs[contentType]; ok { if cf, ok := s.opts.Codecs[contentType]; ok {
@ -555,7 +564,7 @@ func (s *rpcServer) Start() error {
} }
// wait for requests to finish // wait for requests to finish
if wait(s.opts.Context) { if s.wg != nil {
s.wg.Wait() s.wg.Wait()
} }