1
0
mirror of https://github.com/go-kratos/kratos.git synced 2025-01-12 02:28:05 +02:00

feat: adding stream interceptor for logging middleware (#3359)

This commit is contained in:
Abhishek koserwal 2024-09-18 07:29:45 +05:30 committed by GitHub
parent 908e6256a9
commit e1f5dc42b1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 245 additions and 28 deletions

View File

@ -11,6 +11,7 @@ import (
grpcinsecure "google.golang.org/grpc/credentials/insecure"
grpcmd "google.golang.org/grpc/metadata"
"github.com/go-kratos/kratos/v2/internal/matcher"
"github.com/go-kratos/kratos/v2/log"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/registry"
@ -132,6 +133,7 @@ type clientOptions struct {
timeout time.Duration
discovery registry.Discovery
middleware []middleware.Middleware
streamMiddleware []middleware.Middleware
ints []grpc.UnaryClientInterceptor
streamInts []grpc.StreamClientInterceptor
grpcOpts []grpc.DialOption
@ -166,7 +168,7 @@ func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.Clien
unaryClientInterceptor(options.middleware, options.timeout, options.filters),
}
sints := []grpc.StreamClientInterceptor{
streamClientInterceptor(options.filters),
streamClientInterceptor(options.streamMiddleware, options.filters),
}
if len(options.ints) > 0 {
@ -239,7 +241,54 @@ func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, f
}
}
func streamClientInterceptor(filters []selector.NodeFilter) grpc.StreamClientInterceptor {
// wrappedClientStream wraps the grpc.ClientStream and applies middleware
type wrappedClientStream struct {
grpc.ClientStream
ctx context.Context
middleware matcher.Matcher
}
func (w *wrappedClientStream) Context() context.Context {
return w.ctx
}
func (w *wrappedClientStream) SendMsg(m interface{}) error {
h := func(ctx context.Context, req interface{}) (interface{}, error) {
return req, w.ClientStream.SendMsg(m)
}
info, ok := transport.FromClientContext(w.ctx)
if !ok {
return fmt.Errorf("transport value stored in ctx returns: %v", ok)
}
if next := w.middleware.Match(info.Operation()); len(next) > 0 {
h = middleware.Chain(next...)(h)
}
_, err := h(w.ctx, m)
return err
}
func (w *wrappedClientStream) RecvMsg(m interface{}) error {
h := func(ctx context.Context, req interface{}) (interface{}, error) {
return req, w.ClientStream.RecvMsg(m)
}
info, ok := transport.FromClientContext(w.ctx)
if !ok {
return fmt.Errorf("transport value stored in ctx returns: %v", ok)
}
if next := w.middleware.Match(info.Operation()); len(next) > 0 {
h = middleware.Chain(next...)(h)
}
_, err := h(w.ctx, m)
return err
}
func streamClientInterceptor(ms []middleware.Middleware, filters []selector.NodeFilter) grpc.StreamClientInterceptor {
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { // nolint
ctx = transport.NewClientContext(ctx, &Transport{
endpoint: cc.Target(),
@ -249,6 +298,28 @@ func streamClientInterceptor(filters []selector.NodeFilter) grpc.StreamClientInt
})
var p selector.Peer
ctx = selector.NewPeerContext(ctx, &p)
return streamer(ctx, desc, cc, method, opts...)
clientStream, err := streamer(ctx, desc, cc, method, opts...)
if err != nil {
return nil, err
}
h := func(ctx context.Context, req interface{}) (interface{}, error) {
return streamer, nil
}
m := matcher.New()
if len(ms) > 0 {
m.Use(ms...)
middleware.Chain(ms...)(h)
}
wrappedStream := &wrappedClientStream{
ClientStream: clientStream,
ctx: ctx,
middleware: m,
}
return wrappedStream, nil
}
}

View File

@ -2,11 +2,13 @@ package grpc
import (
"context"
"fmt"
"google.golang.org/grpc"
grpcmd "google.golang.org/grpc/metadata"
ic "github.com/go-kratos/kratos/v2/internal/context"
"github.com/go-kratos/kratos/v2/internal/matcher"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
)
@ -48,13 +50,15 @@ func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor {
// wrappedStream is rewrite grpc stream's context
type wrappedStream struct {
grpc.ServerStream
ctx context.Context
ctx context.Context
middleware matcher.Matcher
}
func NewWrappedStream(ctx context.Context, stream grpc.ServerStream) grpc.ServerStream {
func NewWrappedStream(ctx context.Context, stream grpc.ServerStream, m matcher.Matcher) grpc.ServerStream {
return &wrappedStream{
ServerStream: stream,
ctx: ctx,
middleware: m,
}
}
@ -76,7 +80,19 @@ func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor {
replyHeader: headerCarrier(replyHeader),
})
ws := NewWrappedStream(ctx, ss)
h := func(ctx context.Context, req interface{}) (interface{}, error) {
return handler(srv, ss), nil
}
if next := s.streamMiddleware.Match(info.FullMethod); len(next) > 0 {
middleware.Chain(next...)(h)
}
ctx = context.WithValue(ctx, stream{
ServerStream: ss,
streamMiddleware: s.streamMiddleware,
}, ss)
ws := NewWrappedStream(ctx, ss, s.streamMiddleware)
err := handler(srv, ws)
if len(replyHeader) > 0 {
@ -85,3 +101,48 @@ func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor {
return err
}
}
type stream struct {
grpc.ServerStream
streamMiddleware matcher.Matcher
}
func GetStream(ctx context.Context) grpc.ServerStream {
return ctx.Value(stream{}).(grpc.ServerStream)
}
func (w *wrappedStream) SendMsg(m interface{}) error {
h := func(_ context.Context, req interface{}) (interface{}, error) {
return req, w.ServerStream.SendMsg(m)
}
info, ok := transport.FromServerContext(w.ctx)
if !ok {
return fmt.Errorf("transport value stored in ctx returns: %v", ok)
}
if next := w.middleware.Match(info.Operation()); len(next) > 0 {
h = middleware.Chain(next...)(h)
}
_, err := h(w.ctx, m)
return err
}
func (w *wrappedStream) RecvMsg(m interface{}) error {
h := func(_ context.Context, req interface{}) (interface{}, error) {
return req, w.ServerStream.RecvMsg(m)
}
info, ok := transport.FromServerContext(w.ctx)
if !ok {
return fmt.Errorf("transport value stored in ctx returns: %v", ok)
}
if next := w.middleware.Match(info.Operation()); len(next) > 0 {
h = middleware.Chain(next...)(h)
}
_, err := h(w.ctx, m)
return err
}

View File

@ -72,6 +72,12 @@ func Middleware(m ...middleware.Middleware) ServerOption {
}
}
func StreamMiddleware(m ...middleware.Middleware) ServerOption {
return func(s *Server) {
s.streamMiddleware.Use(m...)
}
}
// CustomHealth Checks server.
func CustomHealth() ServerOption {
return func(s *Server) {
@ -117,33 +123,35 @@ func Options(opts ...grpc.ServerOption) ServerOption {
// Server is a gRPC server wrapper.
type Server struct {
*grpc.Server
baseCtx context.Context
tlsConf *tls.Config
lis net.Listener
err error
network string
address string
endpoint *url.URL
timeout time.Duration
middleware matcher.Matcher
unaryInts []grpc.UnaryServerInterceptor
streamInts []grpc.StreamServerInterceptor
grpcOpts []grpc.ServerOption
health *health.Server
customHealth bool
metadata *apimd.Server
adminClean func()
baseCtx context.Context
tlsConf *tls.Config
lis net.Listener
err error
network string
address string
endpoint *url.URL
timeout time.Duration
middleware matcher.Matcher
streamMiddleware matcher.Matcher
unaryInts []grpc.UnaryServerInterceptor
streamInts []grpc.StreamServerInterceptor
grpcOpts []grpc.ServerOption
health *health.Server
customHealth bool
metadata *apimd.Server
adminClean func()
}
// NewServer creates a gRPC server by options.
func NewServer(opts ...ServerOption) *Server {
srv := &Server{
baseCtx: context.Background(),
network: "tcp",
address: ":0",
timeout: 1 * time.Second,
health: health.NewServer(),
middleware: matcher.New(),
baseCtx: context.Background(),
network: "tcp",
address: ":0",
timeout: 1 * time.Second,
health: health.NewServer(),
middleware: matcher.New(),
streamMiddleware: matcher.New(),
}
for _, o := range opts {
o(srv)

View File

@ -12,6 +12,7 @@ import (
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"github.com/go-kratos/kratos/v2/errors"
"github.com/go-kratos/kratos/v2/internal/matcher"
@ -280,6 +281,82 @@ func TestServer_unaryServerInterceptor(t *testing.T) {
}
}
type mockServerStream struct {
ctx context.Context
sentMsg interface{}
recvMsg interface{}
metadata metadata.MD
grpc.ServerStream
}
func (m *mockServerStream) SetHeader(md metadata.MD) error {
m.metadata = md
return nil
}
func (m *mockServerStream) SendHeader(md metadata.MD) error {
m.metadata = md
return nil
}
func (m *mockServerStream) SetTrailer(md metadata.MD) {
m.metadata = md
}
func (m *mockServerStream) Context() context.Context {
return m.ctx
}
func (m *mockServerStream) SendMsg(msg interface{}) error {
m.sentMsg = msg
return nil
}
func (m *mockServerStream) RecvMsg(msg interface{}) error {
m.recvMsg = msg
return nil
}
func TestServer_streamServerInterceptor(t *testing.T) {
u, err := url.Parse("grpc://hello/world")
if err != nil {
t.Errorf("expect %v, got %v", nil, err)
}
srv := &Server{
baseCtx: context.Background(),
endpoint: u,
timeout: time.Duration(10),
middleware: matcher.New(),
streamMiddleware: matcher.New(),
}
srv.streamMiddleware.Use(EmptyMiddleware())
mockStream := &mockServerStream{
ctx: srv.baseCtx,
}
handler := func(_ interface{}, stream grpc.ServerStream) error {
resp := &testResp{Data: "stream hi"}
return stream.SendMsg(resp)
}
info := &grpc.StreamServerInfo{
FullMethod: "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo",
}
err = srv.streamServerInterceptor()(nil, mockStream, info, handler)
if err != nil {
t.Errorf("expect %v, got %v", nil, err)
}
// Check response
resp := mockStream.sentMsg.(*testResp)
if !reflect.DeepEqual("stream hi", resp.Data) {
t.Errorf("expect %s, got %s", "stream hi", resp.Data)
}
}
func TestListener(t *testing.T) {
lis, err := net.Listen("tcp", ":0")
if err != nil {