1
0
mirror of https://github.com/go-kratos/kratos.git synced 2025-01-28 03:57:02 +02:00
kratos/transport/grpc/interceptor.go
haiyux 89583885e4
feat:add stream interceptor use ctx encapsulation (#1770)
* feat:add stream interceptor use ctx encapsulation

* add reply header
2022-01-17 19:58:27 +08:00

84 lines
2.3 KiB
Go

package grpc
import (
"context"
ic "github.com/go-kratos/kratos/v2/internal/context"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
"google.golang.org/grpc"
grpcmd "google.golang.org/grpc/metadata"
)
// unaryServerInterceptor returns a gRPC unary server interceptor. For every
// call it combines the incoming context with the server's base context (via
// ic.Merge), attaches a kratos server Transport carrying the endpoint, the
// full method name and the request/reply headers, applies s.timeout when it
// is positive, and runs the request through s.middleware before reaching the
// gRPC handler. Reply headers written through the Transport are sent with
// grpc.SetHeader after the handler chain returns.
func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		ctx, cancelMerge := ic.Merge(ctx, s.baseCtx)
		defer cancelMerge()

		incoming, _ := grpcmd.FromIncomingContext(ctx)
		outgoing := grpcmd.MD{}
		ctx = transport.NewServerContext(ctx, &Transport{
			endpoint:    s.endpoint.String(),
			operation:   info.FullMethod,
			reqHeader:   headerCarrier(incoming),
			replyHeader: headerCarrier(outgoing),
		})

		if s.timeout > 0 {
			var cancelTimeout context.CancelFunc
			ctx, cancelTimeout = context.WithTimeout(ctx, s.timeout)
			defer cancelTimeout()
		}

		// grpc.UnaryHandler and middleware.Handler share a signature, so the
		// handler can be converted directly instead of wrapped in a closure.
		next := middleware.Handler(handler)
		if len(s.middleware) > 0 {
			next = middleware.Chain(s.middleware...)(next)
		}

		reply, err := next(ctx, req)
		if len(outgoing) > 0 {
			// Best effort: a header-write failure is discarded so it never
			// replaces the handler's own reply/error.
			_ = grpc.SetHeader(ctx, outgoing)
		}
		return reply, err
	}
}
// wrappedStream is a grpc.ServerStream whose Context method reports a
// caller-supplied context in place of the underlying stream's own. Methods
// not overridden are promoted from the embedded stream.
type wrappedStream struct {
	grpc.ServerStream
	ctx context.Context
}

// NewWrappedStream wraps stream so that its Context method returns ctx.
// All other grpc.ServerStream methods not overridden by wrappedStream are
// forwarded to stream.
func NewWrappedStream(ctx context.Context, stream grpc.ServerStream) grpc.ServerStream {
	ws := new(wrappedStream)
	ws.ServerStream = stream
	ws.ctx = ctx
	return ws
}

// Context returns the context given to NewWrappedStream, overriding the
// embedded stream's Context.
func (w *wrappedStream) Context() context.Context {
	return w.ctx
}
// streamServerInterceptor is a gRPC stream server interceptor
func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor {
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
ctx, cancel := ic.Merge(ss.Context(), s.baseCtx)
defer cancel()
md, _ := grpcmd.FromIncomingContext(ctx)
replyHeader := grpcmd.MD{}
ctx = transport.NewServerContext(ctx, &Transport{
endpoint: s.endpoint.String(),
operation: info.FullMethod,
reqHeader: headerCarrier(md),
replyHeader: headerCarrier(replyHeader),
})
ws := NewWrappedStream(ctx, ss)
err := handler(srv, ws)
if len(replyHeader) > 0 {
_ = grpc.SetHeader(ctx, replyHeader)
}
return err
}
}