mirror of
https://github.com/go-kratos/kratos.git
synced 2025-01-12 02:28:05 +02:00
feat: adding stream interceptor for logging middleware (#3359)
This commit is contained in:
parent
908e6256a9
commit
e1f5dc42b1
@ -11,6 +11,7 @@ import (
|
||||
grpcinsecure "google.golang.org/grpc/credentials/insecure"
|
||||
grpcmd "google.golang.org/grpc/metadata"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/internal/matcher"
|
||||
"github.com/go-kratos/kratos/v2/log"
|
||||
"github.com/go-kratos/kratos/v2/middleware"
|
||||
"github.com/go-kratos/kratos/v2/registry"
|
||||
@ -132,6 +133,7 @@ type clientOptions struct {
|
||||
timeout time.Duration
|
||||
discovery registry.Discovery
|
||||
middleware []middleware.Middleware
|
||||
streamMiddleware []middleware.Middleware
|
||||
ints []grpc.UnaryClientInterceptor
|
||||
streamInts []grpc.StreamClientInterceptor
|
||||
grpcOpts []grpc.DialOption
|
||||
@ -166,7 +168,7 @@ func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.Clien
|
||||
unaryClientInterceptor(options.middleware, options.timeout, options.filters),
|
||||
}
|
||||
sints := []grpc.StreamClientInterceptor{
|
||||
streamClientInterceptor(options.filters),
|
||||
streamClientInterceptor(options.streamMiddleware, options.filters),
|
||||
}
|
||||
|
||||
if len(options.ints) > 0 {
|
||||
@ -239,7 +241,54 @@ func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, f
|
||||
}
|
||||
}
|
||||
|
||||
func streamClientInterceptor(filters []selector.NodeFilter) grpc.StreamClientInterceptor {
|
||||
// wrappedClientStream wraps the grpc.ClientStream and applies middleware
|
||||
type wrappedClientStream struct {
|
||||
grpc.ClientStream
|
||||
ctx context.Context
|
||||
middleware matcher.Matcher
|
||||
}
|
||||
|
||||
func (w *wrappedClientStream) Context() context.Context {
|
||||
return w.ctx
|
||||
}
|
||||
|
||||
func (w *wrappedClientStream) SendMsg(m interface{}) error {
|
||||
h := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return req, w.ClientStream.SendMsg(m)
|
||||
}
|
||||
|
||||
info, ok := transport.FromClientContext(w.ctx)
|
||||
if !ok {
|
||||
return fmt.Errorf("transport value stored in ctx returns: %v", ok)
|
||||
}
|
||||
|
||||
if next := w.middleware.Match(info.Operation()); len(next) > 0 {
|
||||
h = middleware.Chain(next...)(h)
|
||||
}
|
||||
|
||||
_, err := h(w.ctx, m)
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *wrappedClientStream) RecvMsg(m interface{}) error {
|
||||
h := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return req, w.ClientStream.RecvMsg(m)
|
||||
}
|
||||
|
||||
info, ok := transport.FromClientContext(w.ctx)
|
||||
if !ok {
|
||||
return fmt.Errorf("transport value stored in ctx returns: %v", ok)
|
||||
}
|
||||
|
||||
if next := w.middleware.Match(info.Operation()); len(next) > 0 {
|
||||
h = middleware.Chain(next...)(h)
|
||||
}
|
||||
|
||||
_, err := h(w.ctx, m)
|
||||
return err
|
||||
}
|
||||
|
||||
func streamClientInterceptor(ms []middleware.Middleware, filters []selector.NodeFilter) grpc.StreamClientInterceptor {
|
||||
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { // nolint
|
||||
ctx = transport.NewClientContext(ctx, &Transport{
|
||||
endpoint: cc.Target(),
|
||||
@ -249,6 +298,28 @@ func streamClientInterceptor(filters []selector.NodeFilter) grpc.StreamClientInt
|
||||
})
|
||||
var p selector.Peer
|
||||
ctx = selector.NewPeerContext(ctx, &p)
|
||||
return streamer(ctx, desc, cc, method, opts...)
|
||||
|
||||
clientStream, err := streamer(ctx, desc, cc, method, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
h := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return streamer, nil
|
||||
}
|
||||
|
||||
m := matcher.New()
|
||||
if len(ms) > 0 {
|
||||
m.Use(ms...)
|
||||
middleware.Chain(ms...)(h)
|
||||
}
|
||||
|
||||
wrappedStream := &wrappedClientStream{
|
||||
ClientStream: clientStream,
|
||||
ctx: ctx,
|
||||
middleware: m,
|
||||
}
|
||||
|
||||
return wrappedStream, nil
|
||||
}
|
||||
}
|
||||
|
@ -2,11 +2,13 @@ package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
grpcmd "google.golang.org/grpc/metadata"
|
||||
|
||||
ic "github.com/go-kratos/kratos/v2/internal/context"
|
||||
"github.com/go-kratos/kratos/v2/internal/matcher"
|
||||
"github.com/go-kratos/kratos/v2/middleware"
|
||||
"github.com/go-kratos/kratos/v2/transport"
|
||||
)
|
||||
@ -48,13 +50,15 @@ func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor {
|
||||
// wrappedStream is rewrite grpc stream's context
|
||||
type wrappedStream struct {
|
||||
grpc.ServerStream
|
||||
ctx context.Context
|
||||
ctx context.Context
|
||||
middleware matcher.Matcher
|
||||
}
|
||||
|
||||
func NewWrappedStream(ctx context.Context, stream grpc.ServerStream) grpc.ServerStream {
|
||||
func NewWrappedStream(ctx context.Context, stream grpc.ServerStream, m matcher.Matcher) grpc.ServerStream {
|
||||
return &wrappedStream{
|
||||
ServerStream: stream,
|
||||
ctx: ctx,
|
||||
middleware: m,
|
||||
}
|
||||
}
|
||||
|
||||
@ -76,7 +80,19 @@ func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor {
|
||||
replyHeader: headerCarrier(replyHeader),
|
||||
})
|
||||
|
||||
ws := NewWrappedStream(ctx, ss)
|
||||
h := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return handler(srv, ss), nil
|
||||
}
|
||||
|
||||
if next := s.streamMiddleware.Match(info.FullMethod); len(next) > 0 {
|
||||
middleware.Chain(next...)(h)
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, stream{
|
||||
ServerStream: ss,
|
||||
streamMiddleware: s.streamMiddleware,
|
||||
}, ss)
|
||||
ws := NewWrappedStream(ctx, ss, s.streamMiddleware)
|
||||
|
||||
err := handler(srv, ws)
|
||||
if len(replyHeader) > 0 {
|
||||
@ -85,3 +101,48 @@ func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
type stream struct {
|
||||
grpc.ServerStream
|
||||
streamMiddleware matcher.Matcher
|
||||
}
|
||||
|
||||
func GetStream(ctx context.Context) grpc.ServerStream {
|
||||
return ctx.Value(stream{}).(grpc.ServerStream)
|
||||
}
|
||||
|
||||
func (w *wrappedStream) SendMsg(m interface{}) error {
|
||||
h := func(_ context.Context, req interface{}) (interface{}, error) {
|
||||
return req, w.ServerStream.SendMsg(m)
|
||||
}
|
||||
|
||||
info, ok := transport.FromServerContext(w.ctx)
|
||||
if !ok {
|
||||
return fmt.Errorf("transport value stored in ctx returns: %v", ok)
|
||||
}
|
||||
|
||||
if next := w.middleware.Match(info.Operation()); len(next) > 0 {
|
||||
h = middleware.Chain(next...)(h)
|
||||
}
|
||||
|
||||
_, err := h(w.ctx, m)
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *wrappedStream) RecvMsg(m interface{}) error {
|
||||
h := func(_ context.Context, req interface{}) (interface{}, error) {
|
||||
return req, w.ServerStream.RecvMsg(m)
|
||||
}
|
||||
|
||||
info, ok := transport.FromServerContext(w.ctx)
|
||||
if !ok {
|
||||
return fmt.Errorf("transport value stored in ctx returns: %v", ok)
|
||||
}
|
||||
|
||||
if next := w.middleware.Match(info.Operation()); len(next) > 0 {
|
||||
h = middleware.Chain(next...)(h)
|
||||
}
|
||||
|
||||
_, err := h(w.ctx, m)
|
||||
return err
|
||||
}
|
||||
|
@ -72,6 +72,12 @@ func Middleware(m ...middleware.Middleware) ServerOption {
|
||||
}
|
||||
}
|
||||
|
||||
func StreamMiddleware(m ...middleware.Middleware) ServerOption {
|
||||
return func(s *Server) {
|
||||
s.streamMiddleware.Use(m...)
|
||||
}
|
||||
}
|
||||
|
||||
// CustomHealth Checks server.
|
||||
func CustomHealth() ServerOption {
|
||||
return func(s *Server) {
|
||||
@ -117,33 +123,35 @@ func Options(opts ...grpc.ServerOption) ServerOption {
|
||||
// Server is a gRPC server wrapper.
|
||||
type Server struct {
|
||||
*grpc.Server
|
||||
baseCtx context.Context
|
||||
tlsConf *tls.Config
|
||||
lis net.Listener
|
||||
err error
|
||||
network string
|
||||
address string
|
||||
endpoint *url.URL
|
||||
timeout time.Duration
|
||||
middleware matcher.Matcher
|
||||
unaryInts []grpc.UnaryServerInterceptor
|
||||
streamInts []grpc.StreamServerInterceptor
|
||||
grpcOpts []grpc.ServerOption
|
||||
health *health.Server
|
||||
customHealth bool
|
||||
metadata *apimd.Server
|
||||
adminClean func()
|
||||
baseCtx context.Context
|
||||
tlsConf *tls.Config
|
||||
lis net.Listener
|
||||
err error
|
||||
network string
|
||||
address string
|
||||
endpoint *url.URL
|
||||
timeout time.Duration
|
||||
middleware matcher.Matcher
|
||||
streamMiddleware matcher.Matcher
|
||||
unaryInts []grpc.UnaryServerInterceptor
|
||||
streamInts []grpc.StreamServerInterceptor
|
||||
grpcOpts []grpc.ServerOption
|
||||
health *health.Server
|
||||
customHealth bool
|
||||
metadata *apimd.Server
|
||||
adminClean func()
|
||||
}
|
||||
|
||||
// NewServer creates a gRPC server by options.
|
||||
func NewServer(opts ...ServerOption) *Server {
|
||||
srv := &Server{
|
||||
baseCtx: context.Background(),
|
||||
network: "tcp",
|
||||
address: ":0",
|
||||
timeout: 1 * time.Second,
|
||||
health: health.NewServer(),
|
||||
middleware: matcher.New(),
|
||||
baseCtx: context.Background(),
|
||||
network: "tcp",
|
||||
address: ":0",
|
||||
timeout: 1 * time.Second,
|
||||
health: health.NewServer(),
|
||||
middleware: matcher.New(),
|
||||
streamMiddleware: matcher.New(),
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(srv)
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/errors"
|
||||
"github.com/go-kratos/kratos/v2/internal/matcher"
|
||||
@ -280,6 +281,82 @@ func TestServer_unaryServerInterceptor(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type mockServerStream struct {
|
||||
ctx context.Context
|
||||
sentMsg interface{}
|
||||
recvMsg interface{}
|
||||
metadata metadata.MD
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
func (m *mockServerStream) SetHeader(md metadata.MD) error {
|
||||
m.metadata = md
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockServerStream) SendHeader(md metadata.MD) error {
|
||||
m.metadata = md
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockServerStream) SetTrailer(md metadata.MD) {
|
||||
m.metadata = md
|
||||
}
|
||||
|
||||
func (m *mockServerStream) Context() context.Context {
|
||||
return m.ctx
|
||||
}
|
||||
|
||||
func (m *mockServerStream) SendMsg(msg interface{}) error {
|
||||
m.sentMsg = msg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockServerStream) RecvMsg(msg interface{}) error {
|
||||
m.recvMsg = msg
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestServer_streamServerInterceptor(t *testing.T) {
|
||||
u, err := url.Parse("grpc://hello/world")
|
||||
if err != nil {
|
||||
t.Errorf("expect %v, got %v", nil, err)
|
||||
}
|
||||
srv := &Server{
|
||||
baseCtx: context.Background(),
|
||||
endpoint: u,
|
||||
timeout: time.Duration(10),
|
||||
middleware: matcher.New(),
|
||||
streamMiddleware: matcher.New(),
|
||||
}
|
||||
|
||||
srv.streamMiddleware.Use(EmptyMiddleware())
|
||||
|
||||
mockStream := &mockServerStream{
|
||||
ctx: srv.baseCtx,
|
||||
}
|
||||
|
||||
handler := func(_ interface{}, stream grpc.ServerStream) error {
|
||||
resp := &testResp{Data: "stream hi"}
|
||||
return stream.SendMsg(resp)
|
||||
}
|
||||
|
||||
info := &grpc.StreamServerInfo{
|
||||
FullMethod: "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo",
|
||||
}
|
||||
|
||||
err = srv.streamServerInterceptor()(nil, mockStream, info, handler)
|
||||
if err != nil {
|
||||
t.Errorf("expect %v, got %v", nil, err)
|
||||
}
|
||||
|
||||
// Check response
|
||||
resp := mockStream.sentMsg.(*testResp)
|
||||
if !reflect.DeepEqual("stream hi", resp.Data) {
|
||||
t.Errorf("expect %s, got %s", "stream hi", resp.Data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListener(t *testing.T) {
|
||||
lis, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
|
Loading…
Reference in New Issue
Block a user