diff --git a/client/grpc/grpc.go b/client/grpc/grpc.go index 439be459..fd5bc853 100644 --- a/client/grpc/grpc.go +++ b/client/grpc/grpc.go @@ -110,12 +110,21 @@ func (g *grpcClient) call(ctx context.Context, node *registry.Node, req client.R var grr error - cc, err := g.pool.getConn(address, grpc.WithDefaultCallOptions(grpc.ForceCodec(cf)), - grpc.WithTimeout(opts.DialTimeout), g.secure(), + grpcDialOptions := []grpc.DialOption{ + grpc.WithDefaultCallOptions(grpc.ForceCodec(cf)), + grpc.WithTimeout(opts.DialTimeout), + g.secure(), grpc.WithDefaultCallOptions( grpc.MaxCallRecvMsgSize(maxRecvMsgSize), grpc.MaxCallSendMsgSize(maxSendMsgSize), - )) + ), + } + + if opts := g.getGrpcDialOptions(); opts != nil { + grpcDialOptions = append(grpcDialOptions, opts...) + } + + cc, err := g.pool.getConn(address, grpcDialOptions...) if err != nil { return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error sending request: %v", err)) } @@ -127,7 +136,11 @@ func (g *grpcClient) call(ctx context.Context, node *registry.Node, req client.R ch := make(chan error, 1) go func() { - err := cc.Invoke(ctx, methodToGRPC(req.Service(), req.Endpoint()), req.Body(), rsp, grpc.CallContentSubtype(cf.Name())) + grpcCallOptions := []grpc.CallOption{grpc.CallContentSubtype(cf.Name())} + if opts := g.getGrpcCallOptions(); opts != nil { + grpcCallOptions = append(grpcCallOptions, opts...) + } + err := cc.Invoke(ctx, methodToGRPC(req.Service(), req.Endpoint()), req.Body(), rsp, grpcCallOptions...) ch <- microError(err) }() @@ -175,7 +188,16 @@ func (g *grpcClient) stream(ctx context.Context, node *registry.Node, req client wc := wrapCodec{cf} - cc, err := grpc.DialContext(dialCtx, address, grpc.WithDefaultCallOptions(grpc.ForceCodec(wc)), g.secure()) + grpcDialOptions := []grpc.DialOption{ + grpc.WithDefaultCallOptions(grpc.ForceCodec(wc)), + g.secure(), + } + + if opts := g.getGrpcDialOptions(); opts != nil { + grpcDialOptions = append(grpcDialOptions, opts...) + } + + cc, err := grpc.DialContext(dialCtx, address, grpcDialOptions...) if err != nil { return nil, errors.InternalServerError("go.micro.client", fmt.Sprintf("Error sending request: %v", err)) } @@ -186,7 +208,11 @@ func (g *grpcClient) stream(ctx context.Context, node *registry.Node, req client ServerStreams: true, } - st, err := cc.NewStream(ctx, desc, methodToGRPC(req.Service(), req.Endpoint())) + grpcCallOptions := []grpc.CallOption{} + if opts := g.getGrpcCallOptions(); opts != nil { + grpcCallOptions = append(grpcCallOptions, opts...) + } + st, err := cc.NewStream(ctx, desc, methodToGRPC(req.Service(), req.Endpoint()), grpcCallOptions...) if err != nil { return nil, errors.InternalServerError("go.micro.client", fmt.Sprintf("Error creating stream: %v", err)) } @@ -514,6 +540,46 @@ func (g *grpcClient) String() string { return "grpc" } +func (g *grpcClient) getGrpcDialOptions() []grpc.DialOption { + if g.opts.CallOptions.Context == nil { + return nil + } + + v := g.opts.CallOptions.Context.Value(grpcDialOptions{}) + + if v == nil { + return nil + } + + opts, ok := v.([]grpc.DialOption) + + if !ok { + return nil + } + + return opts +} + +func (g *grpcClient) getGrpcCallOptions() []grpc.CallOption { + if g.opts.CallOptions.Context == nil { + return nil + } + + v := g.opts.CallOptions.Context.Value(grpcCallOptions{}) + + if v == nil { + return nil + } + + opts, ok := v.([]grpc.CallOption) + + if !ok { + return nil + } + + return opts +} + func newClient(opts ...client.Option) client.Client { options := client.Options{ Codecs: make(map[string]codec.NewCodec), diff --git a/client/grpc/options.go b/client/grpc/options.go index c702ade3..e7f2fceb 100644 --- a/client/grpc/options.go +++ b/client/grpc/options.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "github.com/micro/go-micro/client" + "google.golang.org/grpc" "google.golang.org/grpc/encoding" ) @@ -23,6 +24,8 @@ type codecsKey struct{} type tlsAuth struct{} type maxRecvMsgSizeKey struct{} type maxSendMsgSizeKey struct{} +type grpcDialOptions struct{} +type grpcCallOptions struct{} // gRPC Codec to be used to encode/decode requests for a given content type func Codec(contentType string, c encoding.Codec) client.Option { @@ -72,3 +75,27 @@ func MaxSendMsgSize(s int) client.Option { o.Context = context.WithValue(o.Context, maxSendMsgSizeKey{}, s) } } + +// +// DialOptions to be used to configure gRPC dial options +// +func DialOptions(opts ...grpc.DialOption) client.CallOption { + return func(o *client.CallOptions) { + if o.Context == nil { + o.Context = context.Background() + } + o.Context = context.WithValue(o.Context, grpcDialOptions{}, opts) + } +} + +// +// CallOptions to be used to configure gRPC call options +// +func CallOptions(opts ...grpc.CallOption) client.CallOption { + return func(o *client.CallOptions) { + if o.Context == nil { + o.Context = context.Background() + } + o.Context = context.WithValue(o.Context, grpcCallOptions{}, opts) + } +}