mirror of
https://github.com/go-kratos/kratos.git
synced 2026-05-22 10:15:24 +02:00
revert to select filters (#1656)
This commit is contained in:
@@ -8,7 +8,6 @@ import (
|
|||||||
"github.com/go-kratos/kratos/contrib/registry/consul/v2"
|
"github.com/go-kratos/kratos/contrib/registry/consul/v2"
|
||||||
"github.com/go-kratos/kratos/examples/helloworld/helloworld"
|
"github.com/go-kratos/kratos/examples/helloworld/helloworld"
|
||||||
"github.com/go-kratos/kratos/v2/middleware/recovery"
|
"github.com/go-kratos/kratos/v2/middleware/recovery"
|
||||||
"github.com/go-kratos/kratos/v2/selector"
|
|
||||||
"github.com/go-kratos/kratos/v2/selector/filter"
|
"github.com/go-kratos/kratos/v2/selector/filter"
|
||||||
"github.com/go-kratos/kratos/v2/selector/p2c"
|
"github.com/go-kratos/kratos/v2/selector/p2c"
|
||||||
"github.com/go-kratos/kratos/v2/selector/wrr"
|
"github.com/go-kratos/kratos/v2/selector/wrr"
|
||||||
@@ -32,10 +31,8 @@ func main() {
|
|||||||
// 由于gRPC框架的限制只能使用全局balancer+filter的方式来实现selector
|
// 由于gRPC框架的限制只能使用全局balancer+filter的方式来实现selector
|
||||||
// 这里使用weighted round robin算法的balancer+静态version=1.0.0的Filter
|
// 这里使用weighted round robin算法的balancer+静态version=1.0.0的Filter
|
||||||
grpc.WithBalancerName(wrr.Name),
|
grpc.WithBalancerName(wrr.Name),
|
||||||
grpc.WithNodeFilter(
|
grpc.WithFilter(
|
||||||
func(node selector.Node) bool {
|
filter.Version("1.0.0"),
|
||||||
return node.Version() == "1.0.0"
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ func (d *Default) Select(ctx context.Context, opts ...SelectOption) (selected No
|
|||||||
for _, o := range opts {
|
for _, o := range opts {
|
||||||
o(&options)
|
o(&options)
|
||||||
}
|
}
|
||||||
if len(d.Filters) > 0 {
|
if len(d.Filters) > 0 || len(options.Filters) > 0 {
|
||||||
newNodes := make([]Node, len(nodes))
|
newNodes := make([]Node, len(nodes))
|
||||||
for i, wc := range nodes {
|
for i, wc := range nodes {
|
||||||
newNodes[i] = wc
|
newNodes[i] = wc
|
||||||
@@ -35,6 +35,9 @@ func (d *Default) Select(ctx context.Context, opts ...SelectOption) (selected No
|
|||||||
for _, f := range d.Filters {
|
for _, f := range d.Filters {
|
||||||
newNodes = f(ctx, newNodes)
|
newNodes = f(ctx, newNodes)
|
||||||
}
|
}
|
||||||
|
for _, f := range options.Filters {
|
||||||
|
newNodes = f(ctx, newNodes)
|
||||||
|
}
|
||||||
candidates = make([]WeightedNode, len(newNodes))
|
candidates = make([]WeightedNode, len(newNodes))
|
||||||
for i, n := range newNodes {
|
for i, n := range newNodes {
|
||||||
candidates[i] = n.(WeightedNode)
|
candidates[i] = n.(WeightedNode)
|
||||||
@@ -43,9 +46,6 @@ func (d *Default) Select(ctx context.Context, opts ...SelectOption) (selected No
|
|||||||
candidates = nodes
|
candidates = nodes
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(options.Filters) > 0 {
|
|
||||||
candidates = d.nodeFilter(options.Filters, candidates)
|
|
||||||
}
|
|
||||||
if len(candidates) == 0 {
|
if len(candidates) == 0 {
|
||||||
return nil, nil, ErrNoAvailable
|
return nil, nil, ErrNoAvailable
|
||||||
}
|
}
|
||||||
@@ -56,23 +56,6 @@ func (d *Default) Select(ctx context.Context, opts ...SelectOption) (selected No
|
|||||||
return wn.Raw(), done, nil
|
return wn.Raw(), done, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Default) nodeFilter(filters []NodeFilter, nodes []WeightedNode) []WeightedNode {
|
|
||||||
newNodes := make([]WeightedNode, 0, len(nodes))
|
|
||||||
for _, n := range nodes {
|
|
||||||
var remove bool
|
|
||||||
for _, f := range filters {
|
|
||||||
if !f(n) {
|
|
||||||
remove = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !remove {
|
|
||||||
newNodes = append(newNodes, n)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return newNodes
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply update nodes info.
|
// Apply update nodes info.
|
||||||
func (d *Default) Apply(nodes []Node) {
|
func (d *Default) Apply(nodes []Node) {
|
||||||
weightedNodes := make([]WeightedNode, 0, len(nodes))
|
weightedNodes := make([]WeightedNode, 0, len(nodes))
|
||||||
|
|||||||
@@ -4,7 +4,3 @@ import "context"
|
|||||||
|
|
||||||
// Filter is select filter.
|
// Filter is select filter.
|
||||||
type Filter func(context.Context, []Node) []Node
|
type Filter func(context.Context, []Node) []Node
|
||||||
|
|
||||||
// NodeFilter is node filter.
|
|
||||||
// If it returns false, the node will be removed out from the balancer pick list
|
|
||||||
type NodeFilter func(node Node) bool
|
|
||||||
|
|||||||
@@ -9,12 +9,12 @@ import (
|
|||||||
// Version is version filter.
|
// Version is version filter.
|
||||||
func Version(version string) selector.Filter {
|
func Version(version string) selector.Filter {
|
||||||
return func(_ context.Context, nodes []selector.Node) []selector.Node {
|
return func(_ context.Context, nodes []selector.Node) []selector.Node {
|
||||||
filters := make([]selector.Node, 0, len(nodes))
|
newNodes := nodes[:0]
|
||||||
for _, n := range nodes {
|
for _, n := range nodes {
|
||||||
if n.Version() == version {
|
if n.Version() == version {
|
||||||
filters = append(filters, n)
|
newNodes = append(newNodes, n)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return filters
|
return newNodes
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+3
-3
@@ -2,14 +2,14 @@ package selector
|
|||||||
|
|
||||||
// SelectOptions is Select Options.
|
// SelectOptions is Select Options.
|
||||||
type SelectOptions struct {
|
type SelectOptions struct {
|
||||||
Filters []NodeFilter
|
Filters []Filter
|
||||||
}
|
}
|
||||||
|
|
||||||
// SelectOption is Selector option.
|
// SelectOption is Selector option.
|
||||||
type SelectOption func(*SelectOptions)
|
type SelectOption func(*SelectOptions)
|
||||||
|
|
||||||
// WithNodeFilter with filter options
|
// WithFilter with filter options
|
||||||
func WithNodeFilter(fn ...NodeFilter) SelectOption {
|
func WithFilter(fn ...Filter) SelectOption {
|
||||||
return func(opts *SelectOptions) {
|
return func(opts *SelectOptions) {
|
||||||
opts.Filters = fn
|
opts.Filters = fn
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package selector
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"strconv"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -51,13 +50,13 @@ func (b *mockWeightedNodeBuilder) Build(n Node) WeightedNode {
|
|||||||
|
|
||||||
func mockFilter(version string) Filter {
|
func mockFilter(version string) Filter {
|
||||||
return func(_ context.Context, nodes []Node) []Node {
|
return func(_ context.Context, nodes []Node) []Node {
|
||||||
filters := make([]Node, 0, len(nodes))
|
newNodes := nodes[:0]
|
||||||
for _, n := range nodes {
|
for _, n := range nodes {
|
||||||
if n.Version() == version {
|
if n.Version() == version {
|
||||||
filters = append(filters, n)
|
newNodes = append(newNodes, n)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return filters
|
return newNodes
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -107,9 +106,7 @@ func TestDefault(t *testing.T) {
|
|||||||
Metadata: map[string]string{"weight": "10"},
|
Metadata: map[string]string{"weight": "10"},
|
||||||
}))
|
}))
|
||||||
selector.Apply(nodes)
|
selector.Apply(nodes)
|
||||||
n, done, err := selector.Select(context.Background(), WithNodeFilter(func(node Node) bool {
|
n, done, err := selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0")))
|
||||||
return (node.Version() == "v2.0.0")
|
|
||||||
}))
|
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.NotNil(t, n)
|
assert.NotNil(t, n)
|
||||||
assert.NotNil(t, done)
|
assert.NotNil(t, done)
|
||||||
@@ -121,74 +118,22 @@ func TestDefault(t *testing.T) {
|
|||||||
done(context.Background(), DoneInfo{})
|
done(context.Background(), DoneInfo{})
|
||||||
|
|
||||||
// no v3.0.0 instance
|
// no v3.0.0 instance
|
||||||
n, done, err = selector.Select(context.Background(), WithNodeFilter(func(node Node) bool {
|
n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v3.0.0")))
|
||||||
return (node.Version() == "v3.0.0")
|
|
||||||
}))
|
|
||||||
assert.Equal(t, ErrNoAvailable, err)
|
assert.Equal(t, ErrNoAvailable, err)
|
||||||
assert.Nil(t, done)
|
assert.Nil(t, done)
|
||||||
assert.Nil(t, n)
|
assert.Nil(t, n)
|
||||||
|
|
||||||
// apply zero instance
|
// apply zero instance
|
||||||
selector.Apply([]Node{})
|
selector.Apply([]Node{})
|
||||||
n, done, err = selector.Select(context.Background(), WithNodeFilter(func(node Node) bool {
|
n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0")))
|
||||||
return (node.Version() == "v2.0.0")
|
|
||||||
}))
|
|
||||||
assert.Equal(t, ErrNoAvailable, err)
|
assert.Equal(t, ErrNoAvailable, err)
|
||||||
assert.Nil(t, done)
|
assert.Nil(t, done)
|
||||||
assert.Nil(t, n)
|
assert.Nil(t, n)
|
||||||
|
|
||||||
// apply zero instance
|
// apply zero instance
|
||||||
selector.Apply(nil)
|
selector.Apply(nil)
|
||||||
n, done, err = selector.Select(context.Background(), WithNodeFilter(func(node Node) bool {
|
n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0")))
|
||||||
return (node.Version() == "v2.0.0")
|
|
||||||
}))
|
|
||||||
assert.Equal(t, ErrNoAvailable, err)
|
assert.Equal(t, ErrNoAvailable, err)
|
||||||
assert.Nil(t, done)
|
assert.Nil(t, done)
|
||||||
assert.Nil(t, n)
|
assert.Nil(t, n)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNodeFilterWithRandom(t *testing.T) {
|
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
testBaseFilter(t, 1000, rand.Intn(1000))
|
|
||||||
}
|
|
||||||
|
|
||||||
testBaseFilter(t, 0, rand.Intn(1000))
|
|
||||||
testBaseFilter(t, 1, 1000)
|
|
||||||
testBaseFilter(t, 2, 1000)
|
|
||||||
testBaseFilter(t, 3, 1000)
|
|
||||||
testBaseFilter(t, 1, 0)
|
|
||||||
testBaseFilter(t, 2, 0)
|
|
||||||
testBaseFilter(t, 3, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testBaseFilter(t *testing.T, length int, reservedRatio int) {
|
|
||||||
var raw []WeightedNode
|
|
||||||
var targets map[string]WeightedNode = make(map[string]WeightedNode)
|
|
||||||
for i := 0; i < length; i++ {
|
|
||||||
addr := strconv.FormatInt(int64(i), 10)
|
|
||||||
raw = append(raw, &mockWeightedNode{Node: NewNode(
|
|
||||||
addr,
|
|
||||||
®istry.ServiceInstance{
|
|
||||||
ID: addr,
|
|
||||||
Name: "helloworld",
|
|
||||||
Endpoints: []string{addr},
|
|
||||||
})})
|
|
||||||
if reservedRatio > rand.Intn(length) {
|
|
||||||
targets[addr] = raw[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
f := func(node Node) bool {
|
|
||||||
if _, ok := targets[node.Address()]; ok {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
d := Default{}
|
|
||||||
raw = d.nodeFilter([]NodeFilter{f}, raw)
|
|
||||||
assert.Equal(t, len(targets), len(raw))
|
|
||||||
for _, n := range raw {
|
|
||||||
_, ok := targets[n.Address()]
|
|
||||||
assert.True(t, ok)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -71,14 +71,14 @@ type Picker struct {
|
|||||||
|
|
||||||
// Pick pick instances.
|
// Pick pick instances.
|
||||||
func (p *Picker) Pick(info gBalancer.PickInfo) (gBalancer.PickResult, error) {
|
func (p *Picker) Pick(info gBalancer.PickInfo) (gBalancer.PickResult, error) {
|
||||||
var filters []selector.NodeFilter
|
var filters []selector.Filter
|
||||||
if tr, ok := transport.FromClientContext(info.Ctx); ok {
|
if tr, ok := transport.FromClientContext(info.Ctx); ok {
|
||||||
if gtr, ok := tr.(*Transport); ok {
|
if gtr, ok := tr.(*Transport); ok {
|
||||||
filters = gtr.NodeFilters()
|
filters = gtr.SelectFilters()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
n, done, err := p.selector.Select(info.Ctx, selector.WithNodeFilter(filters...))
|
n, done, err := p.selector.Select(info.Ctx, selector.WithFilter(filters...))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return gBalancer.PickResult{}, err
|
return gBalancer.PickResult{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package grpc
|
package grpc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/go-kratos/kratos/v2/selector"
|
"github.com/go-kratos/kratos/v2/selector"
|
||||||
@@ -24,8 +25,8 @@ func TestBalancerName(t *testing.T) {
|
|||||||
func TestFilters(t *testing.T) {
|
func TestFilters(t *testing.T) {
|
||||||
o := &clientOptions{}
|
o := &clientOptions{}
|
||||||
|
|
||||||
WithNodeFilter(func(selector.Node) bool {
|
WithFilter(func(_ context.Context, nodes []selector.Node) []selector.Node {
|
||||||
return true
|
return nodes
|
||||||
})(o)
|
})(o)
|
||||||
assert.Equal(t, 1, len(o.filters))
|
assert.Equal(t, 1, len(o.filters))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -80,8 +80,8 @@ func WithBalancerName(name string) ClientOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithNodeFilter with select filters
|
// WithFilter with select filters
|
||||||
func WithNodeFilter(filters ...selector.NodeFilter) ClientOption {
|
func WithFilter(filters ...selector.Filter) ClientOption {
|
||||||
return func(o *clientOptions) {
|
return func(o *clientOptions) {
|
||||||
o.filters = filters
|
o.filters = filters
|
||||||
}
|
}
|
||||||
@@ -97,7 +97,7 @@ type clientOptions struct {
|
|||||||
ints []grpc.UnaryClientInterceptor
|
ints []grpc.UnaryClientInterceptor
|
||||||
grpcOpts []grpc.DialOption
|
grpcOpts []grpc.DialOption
|
||||||
balancerName string
|
balancerName string
|
||||||
filters []selector.NodeFilter
|
filters []selector.Filter
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dial returns a GRPC connection.
|
// Dial returns a GRPC connection.
|
||||||
@@ -143,7 +143,7 @@ func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.Clien
|
|||||||
return grpc.DialContext(ctx, options.endpoint, grpcOpts...)
|
return grpc.DialContext(ctx, options.endpoint, grpcOpts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, filters []selector.NodeFilter) grpc.UnaryClientInterceptor {
|
func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, filters []selector.Filter) grpc.UnaryClientInterceptor {
|
||||||
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||||
ctx = transport.NewClientContext(ctx, &Transport{
|
ctx = transport.NewClientContext(ctx, &Transport{
|
||||||
endpoint: cc.Target(),
|
endpoint: cc.Target(),
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ type Transport struct {
|
|||||||
operation string
|
operation string
|
||||||
reqHeader headerCarrier
|
reqHeader headerCarrier
|
||||||
replyHeader headerCarrier
|
replyHeader headerCarrier
|
||||||
filters []selector.NodeFilter
|
filters []selector.Filter
|
||||||
}
|
}
|
||||||
|
|
||||||
// Kind returns the transport kind.
|
// Kind returns the transport kind.
|
||||||
@@ -42,8 +42,8 @@ func (tr *Transport) ReplyHeader() transport.Header {
|
|||||||
return tr.replyHeader
|
return tr.replyHeader
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filters returns the client select filters.
|
// SelectFilters returns the client select filters.
|
||||||
func (tr *Transport) NodeFilters() []selector.NodeFilter {
|
func (tr *Transport) SelectFilters() []selector.Filter {
|
||||||
return tr.filters
|
return tr.filters
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user