mirror of
https://github.com/go-kratos/kratos.git
synced 2025-03-17 21:07:54 +02:00
feat(middleware): add selector matcher (#2239)
* feat(middleware): add selector matcher Co-authored-by: chenzhihui <chenzhihui@bilibili.com>
This commit is contained in:
parent
377356d04d
commit
f3b0da3f04
62
internal/matcher/middleware.go
Normal file
62
internal/matcher/middleware.go
Normal file
@ -0,0 +1,62 @@
|
||||
package matcher
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/middleware"
|
||||
)
|
||||
|
||||
// Matcher is a middleware matcher.
|
||||
type Matcher interface {
|
||||
Use(ms ...middleware.Middleware)
|
||||
Add(selector string, ms ...middleware.Middleware)
|
||||
Match(operation string) []middleware.Middleware
|
||||
}
|
||||
|
||||
// New new a middleware matcher.
|
||||
func New() Matcher {
|
||||
return &matcher{
|
||||
matchs: make(map[string][]middleware.Middleware),
|
||||
}
|
||||
}
|
||||
|
||||
type matcher struct {
|
||||
prefix []string
|
||||
defaults []middleware.Middleware
|
||||
matchs map[string][]middleware.Middleware
|
||||
}
|
||||
|
||||
func (m *matcher) Use(ms ...middleware.Middleware) {
|
||||
m.defaults = ms
|
||||
}
|
||||
|
||||
func (m *matcher) Add(selector string, ms ...middleware.Middleware) {
|
||||
if strings.HasSuffix(selector, "*") {
|
||||
selector = strings.TrimSuffix(selector, "*")
|
||||
m.prefix = append(m.prefix, selector)
|
||||
// sort the prefix:
|
||||
// - /foo/bar
|
||||
// - /foo
|
||||
sort.Slice(m.prefix, func(i, j int) bool {
|
||||
return m.prefix[i] > m.prefix[j]
|
||||
})
|
||||
}
|
||||
m.matchs[selector] = ms
|
||||
}
|
||||
|
||||
func (m *matcher) Match(operation string) []middleware.Middleware {
|
||||
ms := make([]middleware.Middleware, 0, len(m.defaults))
|
||||
if len(m.defaults) > 0 {
|
||||
ms = append(ms, m.defaults...)
|
||||
}
|
||||
if next, ok := m.matchs[operation]; ok {
|
||||
return append(ms, next...)
|
||||
}
|
||||
for _, prefix := range m.prefix {
|
||||
if strings.HasPrefix(operation, prefix) {
|
||||
return append(ms, m.matchs[prefix]...)
|
||||
}
|
||||
}
|
||||
return ms
|
||||
}
|
62
internal/matcher/middleware_test.go
Normal file
62
internal/matcher/middleware_test.go
Normal file
@ -0,0 +1,62 @@
|
||||
package matcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/middleware"
|
||||
)
|
||||
|
||||
func logging(module string) middleware.Middleware {
|
||||
return func(handler middleware.Handler) middleware.Handler {
|
||||
return func(ctx context.Context, req interface{}) (reply interface{}, err error) {
|
||||
return module, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func equal(ms []middleware.Middleware, modules ...string) bool {
|
||||
if len(ms) == 0 {
|
||||
return false
|
||||
}
|
||||
for i, m := range ms {
|
||||
x, _ := m(nil)(nil, nil)
|
||||
if x != modules[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestMatcher(t *testing.T) {
|
||||
m := New()
|
||||
m.Use(logging("logging"))
|
||||
m.Add("*", logging("*"))
|
||||
m.Add("/foo/*", logging("foo/*"))
|
||||
m.Add("/foo/bar/*", logging("foo/bar/*"))
|
||||
m.Add("/foo/bar", logging("foo/bar"))
|
||||
|
||||
if ms := m.Match("/"); len(ms) != 2 {
|
||||
t.Fatal("not equal")
|
||||
} else if !equal(ms, "logging", "*") {
|
||||
t.Fatal("not equal")
|
||||
}
|
||||
|
||||
if ms := m.Match("/foo/xxx"); len(ms) != 2 {
|
||||
t.Fatal("not equal")
|
||||
} else if !equal(ms, "logging", "foo/*") {
|
||||
t.Fatal("not equal")
|
||||
}
|
||||
|
||||
if ms := m.Match("/foo/bar"); len(ms) != 2 {
|
||||
t.Fatal("not equal")
|
||||
} else if !equal(ms, "logging", "foo/bar") {
|
||||
t.Fatal("not equal")
|
||||
}
|
||||
|
||||
if ms := m.Match("/foo/bar/x"); len(ms) != 2 {
|
||||
t.Fatal("not equal")
|
||||
} else if !equal(ms, "logging", "foo/bar/*") {
|
||||
t.Fatal("not equal")
|
||||
}
|
||||
}
|
@ -33,8 +33,8 @@ func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor {
|
||||
h := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
if len(s.middleware) > 0 {
|
||||
h = middleware.Chain(s.middleware...)(h)
|
||||
if next := s.middleware.Match(tr.Operation()); len(next) > 0 {
|
||||
h = middleware.Chain(next...)(h)
|
||||
}
|
||||
reply, err := h(ctx, req)
|
||||
if len(replyHeader) > 0 {
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/internal/endpoint"
|
||||
"github.com/go-kratos/kratos/v2/internal/matcher"
|
||||
|
||||
apimd "github.com/go-kratos/kratos/v2/api/metadata"
|
||||
|
||||
@ -62,7 +63,7 @@ func Logger(logger log.Logger) ServerOption {
|
||||
// Middleware with server middleware.
|
||||
func Middleware(m ...middleware.Middleware) ServerOption {
|
||||
return func(s *Server) {
|
||||
s.middleware = m
|
||||
s.middleware.Use(m...)
|
||||
}
|
||||
}
|
||||
|
||||
@ -112,7 +113,7 @@ type Server struct {
|
||||
address string
|
||||
endpoint *url.URL
|
||||
timeout time.Duration
|
||||
middleware []middleware.Middleware
|
||||
middleware matcher.Matcher
|
||||
unaryInts []grpc.UnaryServerInterceptor
|
||||
streamInts []grpc.StreamServerInterceptor
|
||||
grpcOpts []grpc.ServerOption
|
||||
@ -123,11 +124,12 @@ type Server struct {
|
||||
// NewServer creates a gRPC server by options.
|
||||
func NewServer(opts ...ServerOption) *Server {
|
||||
srv := &Server{
|
||||
baseCtx: context.Background(),
|
||||
network: "tcp",
|
||||
address: ":0",
|
||||
timeout: 1 * time.Second,
|
||||
health: health.NewServer(),
|
||||
baseCtx: context.Background(),
|
||||
network: "tcp",
|
||||
address: ":0",
|
||||
timeout: 1 * time.Second,
|
||||
health: health.NewServer(),
|
||||
middleware: matcher.New(),
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(srv)
|
||||
@ -163,6 +165,15 @@ func NewServer(opts ...ServerOption) *Server {
|
||||
return srv
|
||||
}
|
||||
|
||||
// Use uses a service middleware with selector.
|
||||
// selector:
|
||||
// - '/*'
|
||||
// - '/helloworld.v1.Greeter/*'
|
||||
// - '/helloworld.v1.Greeter/SayHello'
|
||||
func (s *Server) Use(selector string, m ...middleware.Middleware) {
|
||||
s.middleware.Add(selector, m...)
|
||||
}
|
||||
|
||||
// Endpoint return a real address to registry endpoint.
|
||||
// examples:
|
||||
// grpc://127.0.0.1:9000?isSecure=false
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/errors"
|
||||
"github.com/go-kratos/kratos/v2/internal/matcher"
|
||||
pb "github.com/go-kratos/kratos/v2/internal/testdata/helloworld"
|
||||
"github.com/go-kratos/kratos/v2/middleware"
|
||||
"github.com/go-kratos/kratos/v2/transport"
|
||||
@ -198,17 +199,6 @@ func TestTimeout(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware(t *testing.T) {
|
||||
o := &Server{}
|
||||
v := []middleware.Middleware{
|
||||
func(middleware.Handler) middleware.Handler { return nil },
|
||||
}
|
||||
Middleware(v...)(o)
|
||||
if !reflect.DeepEqual(v, o.middleware) {
|
||||
t.Errorf("expect %v, got %v", v, o.middleware)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSConfig(t *testing.T) {
|
||||
o := &Server{}
|
||||
v := &tls.Config{}
|
||||
@ -273,9 +263,10 @@ func TestServer_unaryServerInterceptor(t *testing.T) {
|
||||
srv := &Server{
|
||||
baseCtx: context.Background(),
|
||||
endpoint: u,
|
||||
middleware: []middleware.Middleware{EmptyMiddleware()},
|
||||
timeout: time.Duration(10),
|
||||
middleware: matcher.New(),
|
||||
}
|
||||
srv.middleware.Use(EmptyMiddleware())
|
||||
req := &struct{}{}
|
||||
rv, err := srv.unaryServerInterceptor()(context.TODO(), req, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (i interface{}, e error) {
|
||||
return &testResp{Data: "hi"}, nil
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/middleware"
|
||||
"github.com/go-kratos/kratos/v2/transport"
|
||||
"github.com/go-kratos/kratos/v2/transport/http/binding"
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
@ -89,7 +90,10 @@ func (c *wrapper) Query() url.Values {
|
||||
func (c *wrapper) Request() *http.Request { return c.req }
|
||||
func (c *wrapper) Response() http.ResponseWriter { return c.res }
|
||||
func (c *wrapper) Middleware(h middleware.Handler) middleware.Handler {
|
||||
return middleware.Chain(c.router.srv.ms...)(h)
|
||||
if tr, ok := transport.FromServerContext(c.req.Context()); ok {
|
||||
return middleware.Chain(c.router.srv.middleware.Match(tr.Operation())...)(h)
|
||||
}
|
||||
return middleware.Chain(c.router.srv.middleware.Match(c.req.URL.Path)...)(h)
|
||||
}
|
||||
func (c *wrapper) Bind(v interface{}) error { return c.router.srv.dec(c.req, v) }
|
||||
func (c *wrapper) BindVars(v interface{}) error { return binding.BindQuery(c.Vars(), v) }
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/internal/endpoint"
|
||||
"github.com/go-kratos/kratos/v2/internal/matcher"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/internal/host"
|
||||
"github.com/go-kratos/kratos/v2/log"
|
||||
@ -58,7 +59,7 @@ func Logger(logger log.Logger) ServerOption {
|
||||
// Middleware with service middleware option.
|
||||
func Middleware(m ...middleware.Middleware) ServerOption {
|
||||
return func(o *Server) {
|
||||
o.ms = m
|
||||
o.middleware.Use(m...)
|
||||
}
|
||||
}
|
||||
|
||||
@ -124,7 +125,7 @@ type Server struct {
|
||||
address string
|
||||
timeout time.Duration
|
||||
filters []FilterFunc
|
||||
ms []middleware.Middleware
|
||||
middleware matcher.Matcher
|
||||
dec DecodeRequestFunc
|
||||
enc EncodeResponseFunc
|
||||
ene EncodeErrorFunc
|
||||
@ -138,6 +139,7 @@ func NewServer(opts ...ServerOption) *Server {
|
||||
network: "tcp",
|
||||
address: ":0",
|
||||
timeout: 1 * time.Second,
|
||||
middleware: matcher.New(),
|
||||
dec: DefaultRequestDecoder,
|
||||
enc: DefaultResponseEncoder,
|
||||
ene: DefaultErrorEncoder,
|
||||
@ -157,6 +159,15 @@ func NewServer(opts ...ServerOption) *Server {
|
||||
return srv
|
||||
}
|
||||
|
||||
// Use uses a service middleware with selector.
|
||||
// selector:
|
||||
// - '/*'
|
||||
// - '/helloworld.v1.Greeter/*'
|
||||
// - '/helloworld.v1.Greeter/SayHello'
|
||||
func (s *Server) Use(selector string, m ...middleware.Middleware) {
|
||||
s.middleware.Add(selector, m...)
|
||||
}
|
||||
|
||||
// WalkRoute walks the router and all its sub-routers, calling walkFn for each route in the tree.
|
||||
func (s *Server) WalkRoute(fn WalkRouteFunc) error {
|
||||
return s.router.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
|
||||
@ -229,10 +240,10 @@ func (s *Server) filter() mux.MiddlewareFunc {
|
||||
|
||||
tr := &Transport{
|
||||
operation: pathTemplate,
|
||||
pathTemplate: pathTemplate,
|
||||
reqHeader: headerCarrier(req.Header),
|
||||
replyHeader: headerCarrier(w.Header()),
|
||||
request: req,
|
||||
pathTemplate: pathTemplate,
|
||||
}
|
||||
if s.endpoint != nil {
|
||||
tr.endpoint = s.endpoint.String()
|
||||
|
@ -14,7 +14,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/errors"
|
||||
"github.com/go-kratos/kratos/v2/middleware"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/internal/host"
|
||||
)
|
||||
@ -313,17 +312,6 @@ func TestLogger(t *testing.T) {
|
||||
// todo
|
||||
}
|
||||
|
||||
func TestMiddleware(t *testing.T) {
|
||||
o := &Server{}
|
||||
v := []middleware.Middleware{
|
||||
func(middleware.Handler) middleware.Handler { return nil },
|
||||
}
|
||||
Middleware(v...)(o)
|
||||
if !reflect.DeepEqual(v, o.ms) {
|
||||
t.Errorf("expected %v got %v", v, o.ms)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestDecoder(t *testing.T) {
|
||||
o := &Server{}
|
||||
v := func(*http.Request, interface{}) error { return nil }
|
||||
|
Loading…
x
Reference in New Issue
Block a user