mirror of
https://github.com/go-kratos/kratos.git
synced 2026-05-22 10:15:24 +02:00
feat:add stream interceptor use ctx encapsulation (#1770)
* feat:add stream interceptor use ctx encapsulation * add reply header
This commit is contained in:
@@ -0,0 +1,83 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
ic "github.com/go-kratos/kratos/v2/internal/context"
|
||||
"github.com/go-kratos/kratos/v2/middleware"
|
||||
"github.com/go-kratos/kratos/v2/transport"
|
||||
"google.golang.org/grpc"
|
||||
grpcmd "google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
// unaryServerInterceptor is a gRPC unary server interceptor
|
||||
func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
ctx, cancel := ic.Merge(ctx, s.baseCtx)
|
||||
defer cancel()
|
||||
md, _ := grpcmd.FromIncomingContext(ctx)
|
||||
replyHeader := grpcmd.MD{}
|
||||
ctx = transport.NewServerContext(ctx, &Transport{
|
||||
endpoint: s.endpoint.String(),
|
||||
operation: info.FullMethod,
|
||||
reqHeader: headerCarrier(md),
|
||||
replyHeader: headerCarrier(replyHeader),
|
||||
})
|
||||
if s.timeout > 0 {
|
||||
ctx, cancel = context.WithTimeout(ctx, s.timeout)
|
||||
defer cancel()
|
||||
}
|
||||
h := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
if len(s.middleware) > 0 {
|
||||
h = middleware.Chain(s.middleware...)(h)
|
||||
}
|
||||
reply, err := h(ctx, req)
|
||||
if len(replyHeader) > 0 {
|
||||
_ = grpc.SetHeader(ctx, replyHeader)
|
||||
}
|
||||
return reply, err
|
||||
}
|
||||
}
|
||||
|
||||
// wrappedStream is rewrite grpc stream's context
|
||||
type wrappedStream struct {
|
||||
grpc.ServerStream
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewWrappedStream(ctx context.Context, stream grpc.ServerStream) grpc.ServerStream {
|
||||
return &wrappedStream{
|
||||
ServerStream: stream,
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *wrappedStream) Context() context.Context {
|
||||
return w.ctx
|
||||
}
|
||||
|
||||
// streamServerInterceptor is a gRPC stream server interceptor
|
||||
func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor {
|
||||
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
ctx, cancel := ic.Merge(ss.Context(), s.baseCtx)
|
||||
defer cancel()
|
||||
md, _ := grpcmd.FromIncomingContext(ctx)
|
||||
replyHeader := grpcmd.MD{}
|
||||
ctx = transport.NewServerContext(ctx, &Transport{
|
||||
endpoint: s.endpoint.String(),
|
||||
operation: info.FullMethod,
|
||||
reqHeader: headerCarrier(md),
|
||||
replyHeader: headerCarrier(replyHeader),
|
||||
})
|
||||
|
||||
ws := NewWrappedStream(ctx, ss)
|
||||
|
||||
err := handler(srv, ws)
|
||||
if len(replyHeader) > 0 {
|
||||
_ = grpc.SetHeader(ctx, replyHeader)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
+23
-38
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/go-kratos/kratos/v2/internal/endpoint"
|
||||
|
||||
apimd "github.com/go-kratos/kratos/v2/api/metadata"
|
||||
ic "github.com/go-kratos/kratos/v2/internal/context"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/internal/host"
|
||||
"github.com/go-kratos/kratos/v2/log"
|
||||
"github.com/go-kratos/kratos/v2/middleware"
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/health"
|
||||
"google.golang.org/grpc/health/grpc_health_v1"
|
||||
grpcmd "google.golang.org/grpc/metadata"
|
||||
|
||||
"google.golang.org/grpc/reflection"
|
||||
)
|
||||
|
||||
@@ -84,7 +84,14 @@ func Listener(lis net.Listener) ServerOption {
|
||||
// UnaryInterceptor returns a ServerOption that sets the UnaryServerInterceptor for the server.
|
||||
func UnaryInterceptor(in ...grpc.UnaryServerInterceptor) ServerOption {
|
||||
return func(s *Server) {
|
||||
s.ints = in
|
||||
s.unaryInts = in
|
||||
}
|
||||
}
|
||||
|
||||
// StreamInterceptor returns a ServerOption that sets the StreamServerInterceptor for the server.
|
||||
func StreamInterceptor(in ...grpc.StreamServerInterceptor) ServerOption {
|
||||
return func(s *Server) {
|
||||
s.streamInts = in
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,7 +115,8 @@ type Server struct {
|
||||
timeout time.Duration
|
||||
log *log.Helper
|
||||
middleware []middleware.Middleware
|
||||
ints []grpc.UnaryServerInterceptor
|
||||
unaryInts []grpc.UnaryServerInterceptor
|
||||
streamInts []grpc.StreamServerInterceptor
|
||||
grpcOpts []grpc.ServerOption
|
||||
health *health.Server
|
||||
metadata *apimd.Server
|
||||
@@ -127,14 +135,21 @@ func NewServer(opts ...ServerOption) *Server {
|
||||
for _, o := range opts {
|
||||
o(srv)
|
||||
}
|
||||
ints := []grpc.UnaryServerInterceptor{
|
||||
unaryInts := []grpc.UnaryServerInterceptor{
|
||||
srv.unaryServerInterceptor(),
|
||||
}
|
||||
if len(srv.ints) > 0 {
|
||||
ints = append(ints, srv.ints...)
|
||||
streamInts := []grpc.StreamServerInterceptor{
|
||||
srv.streamServerInterceptor(),
|
||||
}
|
||||
if len(srv.unaryInts) > 0 {
|
||||
unaryInts = append(unaryInts, srv.unaryInts...)
|
||||
}
|
||||
if len(srv.streamInts) > 0 {
|
||||
streamInts = append(streamInts, srv.streamInts...)
|
||||
}
|
||||
grpcOpts := []grpc.ServerOption{
|
||||
grpc.ChainUnaryInterceptor(ints...),
|
||||
grpc.ChainUnaryInterceptor(unaryInts...),
|
||||
grpc.ChainStreamInterceptor(streamInts...),
|
||||
}
|
||||
if srv.tlsConf != nil {
|
||||
grpcOpts = append(grpcOpts, grpc.Creds(credentials.NewTLS(srv.tlsConf)))
|
||||
@@ -198,33 +213,3 @@ func (s *Server) listenAndEndpoint() error {
|
||||
s.endpoint = endpoint.NewEndpoint("grpc", addr, s.tlsConf != nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
ctx, cancel := ic.Merge(ctx, s.baseCtx)
|
||||
defer cancel()
|
||||
md, _ := grpcmd.FromIncomingContext(ctx)
|
||||
replyHeader := grpcmd.MD{}
|
||||
ctx = transport.NewServerContext(ctx, &Transport{
|
||||
endpoint: s.endpoint.String(),
|
||||
operation: info.FullMethod,
|
||||
reqHeader: headerCarrier(md),
|
||||
replyHeader: headerCarrier(replyHeader),
|
||||
})
|
||||
if s.timeout > 0 {
|
||||
ctx, cancel = context.WithTimeout(ctx, s.timeout)
|
||||
defer cancel()
|
||||
}
|
||||
h := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
if len(s.middleware) > 0 {
|
||||
h = middleware.Chain(s.middleware...)(h)
|
||||
}
|
||||
reply, err := h(ctx, req)
|
||||
if len(replyHeader) > 0 {
|
||||
_ = grpc.SetHeader(ctx, replyHeader)
|
||||
}
|
||||
return reply, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -207,8 +207,24 @@ func TestUnaryInterceptor(t *testing.T) {
|
||||
},
|
||||
}
|
||||
UnaryInterceptor(v...)(o)
|
||||
if !reflect.DeepEqual(v, o.ints) {
|
||||
t.Errorf("expect %v, got %v", v, o.ints)
|
||||
if !reflect.DeepEqual(v, o.unaryInts) {
|
||||
t.Errorf("expect %v, got %v", v, o.unaryInts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStramInterceptor(t *testing.T) {
|
||||
o := &Server{}
|
||||
v := []grpc.StreamServerInterceptor{
|
||||
func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
return nil
|
||||
},
|
||||
func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
StreamInterceptor(v...)(o)
|
||||
if !reflect.DeepEqual(v, o.streamInts) {
|
||||
t.Errorf("expect %v, got %v", v, o.streamInts)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user