mirror of
				https://github.com/go-kratos/kratos.git
				synced 2025-10-30 23:47:59 +02:00 
			
		
		
		
	refactor: unify selector filter (#2277)
* unify selector Co-authored-by: caoguoliang01 <caoguoliang01@bilibili.com> Co-authored-by: chenzhihui <zhihui_chen@foxmail.com>
This commit is contained in:
		| @@ -9,7 +9,6 @@ import ( | ||||
| type Default struct { | ||||
| 	NodeBuilder WeightedNodeBuilder | ||||
| 	Balancer    Balancer | ||||
| 	Filters     []Filter | ||||
|  | ||||
| 	nodes atomic.Value | ||||
| } | ||||
| @@ -27,16 +26,13 @@ func (d *Default) Select(ctx context.Context, opts ...SelectOption) (selected No | ||||
| 	for _, o := range opts { | ||||
| 		o(&options) | ||||
| 	} | ||||
| 	if len(d.Filters) > 0 || len(options.Filters) > 0 { | ||||
| 	if len(options.NodeFilters) > 0 { | ||||
| 		newNodes := make([]Node, len(nodes)) | ||||
| 		for i, wc := range nodes { | ||||
| 			newNodes[i] = wc | ||||
| 		} | ||||
| 		for _, f := range d.Filters { | ||||
| 			newNodes = f(ctx, newNodes) | ||||
| 		} | ||||
| 		for _, f := range options.Filters { | ||||
| 			newNodes = f(ctx, newNodes) | ||||
| 		for _, filter := range options.NodeFilters { | ||||
| 			newNodes = filter(ctx, newNodes) | ||||
| 		} | ||||
| 		candidates = make([]WeightedNode, len(newNodes)) | ||||
| 		for i, n := range newNodes { | ||||
| @@ -74,7 +70,6 @@ func (d *Default) Apply(nodes []Node) { | ||||
| type DefaultBuilder struct { | ||||
| 	Node     WeightedNodeBuilder | ||||
| 	Balancer BalancerBuilder | ||||
| 	Filters  []Filter | ||||
| } | ||||
|  | ||||
| // Build create builder | ||||
| @@ -82,6 +77,5 @@ func (db *DefaultBuilder) Build() Selector { | ||||
| 	return &Default{ | ||||
| 		NodeBuilder: db.Node, | ||||
| 		Balancer:    db.Balancer.Build(), | ||||
| 		Filters:     db.Filters, | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -2,5 +2,5 @@ package selector | ||||
|  | ||||
| import "context" | ||||
|  | ||||
| // Filter is select filter. | ||||
| type Filter func(context.Context, []Node) []Node | ||||
| // NodeFilter is select filter. | ||||
| type NodeFilter func(context.Context, []Node) []Node | ||||
|   | ||||
| @@ -7,7 +7,7 @@ import ( | ||||
| ) | ||||
|  | ||||
| // Version is version filter. | ||||
| func Version(version string) selector.Filter { | ||||
| func Version(version string) selector.NodeFilter { | ||||
| 	return func(_ context.Context, nodes []selector.Node) []selector.Node { | ||||
| 		newNodes := nodes[:0] | ||||
| 		for _, n := range nodes { | ||||
|   | ||||
							
								
								
									
										13
									
								
								selector/global.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								selector/global.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,13 @@ | ||||
| package selector | ||||
|  | ||||
| var globalSelector Builder | ||||
|  | ||||
| // GlobalSelector returns global selector builder. | ||||
| func GlobalSelector() Builder { | ||||
| 	return globalSelector | ||||
| } | ||||
|  | ||||
| // SetGlobalSelector set global selector builder. | ||||
| func SetGlobalSelector(builder Builder) { | ||||
| 	globalSelector = builder | ||||
| } | ||||
| @@ -2,15 +2,15 @@ package selector | ||||
|  | ||||
| // SelectOptions is Select Options. | ||||
| type SelectOptions struct { | ||||
| 	Filters []Filter | ||||
| 	NodeFilters []NodeFilter | ||||
| } | ||||
|  | ||||
| // SelectOption is Selector option. | ||||
| type SelectOption func(*SelectOptions) | ||||
|  | ||||
| // WithFilter with filter options | ||||
| func WithFilter(fn ...Filter) SelectOption { | ||||
| // WithNodeFilter with filter options | ||||
| func WithNodeFilter(fn ...NodeFilter) SelectOption { | ||||
| 	return func(opts *SelectOptions) { | ||||
| 		opts.Filters = fn | ||||
| 		opts.NodeFilters = fn | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -19,20 +19,11 @@ const ( | ||||
|  | ||||
| var _ selector.Balancer = &Balancer{} | ||||
|  | ||||
| // WithFilter with select filters | ||||
| func WithFilter(filters ...selector.Filter) Option { | ||||
| 	return func(o *options) { | ||||
| 		o.filters = filters | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Option is random builder option. | ||||
| type Option func(o *options) | ||||
|  | ||||
| // options is random builder options | ||||
| type options struct { | ||||
| 	filters []selector.Filter | ||||
| } | ||||
| type options struct{} | ||||
|  | ||||
| // New creates a p2c selector. | ||||
| func New(opts ...Option) selector.Selector { | ||||
| @@ -95,7 +86,6 @@ func NewBuilder(opts ...Option) selector.Builder { | ||||
| 		opt(&option) | ||||
| 	} | ||||
| 	return &selector.DefaultBuilder{ | ||||
| 		Filters:  option.filters, | ||||
| 		Balancer: &Builder{}, | ||||
| 		Node:     &ewma.Builder{}, | ||||
| 	} | ||||
|   | ||||
| @@ -16,7 +16,7 @@ import ( | ||||
| ) | ||||
|  | ||||
| func TestWrr3(t *testing.T) { | ||||
| 	p2c := New(WithFilter(filter.Version("v2.0.0"))) | ||||
| 	p2c := New() | ||||
| 	var nodes []selector.Node | ||||
| 	for i := 0; i < 3; i++ { | ||||
| 		addr := fmt.Sprintf("127.0.0.%d:8080", i) | ||||
| @@ -41,7 +41,7 @@ func TestWrr3(t *testing.T) { | ||||
| 			d := time.Duration(rand.Intn(500)) * time.Millisecond | ||||
| 			lk.Unlock() | ||||
| 			time.Sleep(d) | ||||
| 			n, done, err := p2c.Select(context.Background()) | ||||
| 			n, done, err := p2c.Select(context.Background(), selector.WithNodeFilter(filter.Version("v2.0.0"))) | ||||
| 			if err != nil { | ||||
| 				t.Errorf("expect %v, got %v", nil, err) | ||||
| 			} | ||||
| @@ -92,7 +92,7 @@ func TestEmpty(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestOne(t *testing.T) { | ||||
| 	p2c := New(WithFilter(filter.Version("v2.0.0"))) | ||||
| 	p2c := New() | ||||
| 	var nodes []selector.Node | ||||
| 	for i := 0; i < 1; i++ { | ||||
| 		addr := fmt.Sprintf("127.0.0.%d:8080", i) | ||||
| @@ -106,7 +106,7 @@ func TestOne(t *testing.T) { | ||||
| 			})) | ||||
| 	} | ||||
| 	p2c.Apply(nodes) | ||||
| 	n, done, err := p2c.Select(context.Background()) | ||||
| 	n, done, err := p2c.Select(context.Background(), selector.WithNodeFilter(filter.Version("v2.0.0"))) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("expect %v, got %v", nil, err) | ||||
| 	} | ||||
|   | ||||
| @@ -15,20 +15,11 @@ const ( | ||||
|  | ||||
| var _ selector.Balancer = &Balancer{} // Name is balancer name | ||||
|  | ||||
| // WithFilter with select filters | ||||
| func WithFilter(filters ...selector.Filter) Option { | ||||
| 	return func(o *options) { | ||||
| 		o.filters = filters | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Option is random builder option. | ||||
| type Option func(o *options) | ||||
|  | ||||
| // options is random builder options | ||||
| type options struct { | ||||
| 	filters []selector.Filter | ||||
| } | ||||
| type options struct{} | ||||
|  | ||||
| // Balancer is a random balancer. | ||||
| type Balancer struct{} | ||||
| @@ -56,7 +47,6 @@ func NewBuilder(opts ...Option) selector.Builder { | ||||
| 		opt(&option) | ||||
| 	} | ||||
| 	return &selector.DefaultBuilder{ | ||||
| 		Filters:  option.filters, | ||||
| 		Balancer: &Builder{}, | ||||
| 		Node:     &direct.Builder{}, | ||||
| 	} | ||||
|   | ||||
| @@ -10,7 +10,7 @@ import ( | ||||
| ) | ||||
|  | ||||
| func TestWrr(t *testing.T) { | ||||
| 	random := New(WithFilter(filter.Version("v2.0.0"))) | ||||
| 	random := New() | ||||
| 	var nodes []selector.Node | ||||
| 	nodes = append(nodes, selector.NewNode( | ||||
| 		"http", | ||||
| @@ -31,7 +31,7 @@ func TestWrr(t *testing.T) { | ||||
| 	random.Apply(nodes) | ||||
| 	var count1, count2 int | ||||
| 	for i := 0; i < 200; i++ { | ||||
| 		n, done, err := random.Select(context.Background()) | ||||
| 		n, done, err := random.Select(context.Background(), selector.WithNodeFilter(filter.Version("v2.0.0"))) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("expect no error, got %v", err) | ||||
| 		} | ||||
|   | ||||
| @@ -57,7 +57,7 @@ type DoneInfo struct { | ||||
| 	// Response Error | ||||
| 	Err error | ||||
| 	// Response Metadata | ||||
| 	ReplyMeta ReplyMeta | ||||
| 	ReplyMD ReplyMD | ||||
|  | ||||
| 	// BytesSent indicates if any bytes have been sent to the server. | ||||
| 	BytesSent bool | ||||
| @@ -65,8 +65,8 @@ type DoneInfo struct { | ||||
| 	BytesReceived bool | ||||
| } | ||||
|  | ||||
| // ReplyMeta is Reply Metadata. | ||||
| type ReplyMeta interface { | ||||
| // ReplyMD is Reply Metadata. | ||||
| type ReplyMD interface { | ||||
| 	Get(key string) string | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -49,7 +49,7 @@ func (b *mockWeightedNodeBuilder) Build(n Node) WeightedNode { | ||||
| 	return &mockWeightedNode{Node: n} | ||||
| } | ||||
|  | ||||
| func mockFilter(version string) Filter { | ||||
| func mockFilter(version string) NodeFilter { | ||||
| 	return func(_ context.Context, nodes []Node) []Node { | ||||
| 		newNodes := nodes[:0] | ||||
| 		for _, n := range nodes { | ||||
| @@ -83,7 +83,6 @@ func (b *mockBalancer) Pick(ctx context.Context, nodes []WeightedNode) (selected | ||||
| func TestDefault(t *testing.T) { | ||||
| 	builder := DefaultBuilder{ | ||||
| 		Node:     &mockWeightedNodeBuilder{}, | ||||
| 		Filters:  []Filter{mockFilter("v2.0.0")}, | ||||
| 		Balancer: &mockBalancerBuilder{}, | ||||
| 	} | ||||
| 	selector := builder.Build() | ||||
| @@ -109,7 +108,7 @@ func TestDefault(t *testing.T) { | ||||
| 			Metadata:  map[string]string{"weight": "10"}, | ||||
| 		})) | ||||
| 	selector.Apply(nodes) | ||||
| 	n, done, err := selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0"))) | ||||
| 	n, done, err := selector.Select(context.Background(), WithNodeFilter(mockFilter("v2.0.0"))) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("expect %v, got %v", nil, err) | ||||
| 	} | ||||
| @@ -137,7 +136,7 @@ func TestDefault(t *testing.T) { | ||||
| 	done(context.Background(), DoneInfo{}) | ||||
|  | ||||
| 	// no v3.0.0 instance | ||||
| 	n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v3.0.0"))) | ||||
| 	n, done, err = selector.Select(context.Background(), WithNodeFilter(mockFilter("v3.0.0"))) | ||||
| 	if !errors.Is(ErrNoAvailable, err) { | ||||
| 		t.Errorf("expect %v, got %v", ErrNoAvailable, err) | ||||
| 	} | ||||
| @@ -150,7 +149,7 @@ func TestDefault(t *testing.T) { | ||||
|  | ||||
| 	// apply zero instance | ||||
| 	selector.Apply([]Node{}) | ||||
| 	n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0"))) | ||||
| 	n, done, err = selector.Select(context.Background(), WithNodeFilter(mockFilter("v2.0.0"))) | ||||
| 	if !errors.Is(ErrNoAvailable, err) { | ||||
| 		t.Errorf("expect %v, got %v", ErrNoAvailable, err) | ||||
| 	} | ||||
| @@ -163,7 +162,7 @@ func TestDefault(t *testing.T) { | ||||
|  | ||||
| 	// apply zero instance | ||||
| 	selector.Apply(nil) | ||||
| 	n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0"))) | ||||
| 	n, done, err = selector.Select(context.Background(), WithNodeFilter(mockFilter("v2.0.0"))) | ||||
| 	if !errors.Is(ErrNoAvailable, err) { | ||||
| 		t.Errorf("expect %v, got %v", ErrNoAvailable, err) | ||||
| 	} | ||||
|   | ||||
| @@ -15,20 +15,11 @@ const ( | ||||
|  | ||||
| var _ selector.Balancer = &Balancer{} // Name is balancer name | ||||
|  | ||||
| // WithFilter with select filters | ||||
| func WithFilter(filters ...selector.Filter) Option { | ||||
| 	return func(o *options) { | ||||
| 		o.filters = filters | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Option is random builder option. | ||||
| type Option func(o *options) | ||||
|  | ||||
| // options is random builder options | ||||
| type options struct { | ||||
| 	filters []selector.Filter | ||||
| } | ||||
| type options struct{} | ||||
|  | ||||
| // Balancer is a random balancer. | ||||
| type Balancer struct { | ||||
| @@ -77,7 +68,6 @@ func NewBuilder(opts ...Option) selector.Builder { | ||||
| 		opt(&option) | ||||
| 	} | ||||
| 	return &selector.DefaultBuilder{ | ||||
| 		Filters:  option.filters, | ||||
| 		Balancer: &Builder{}, | ||||
| 		Node:     &direct.Builder{}, | ||||
| 	} | ||||
|   | ||||
| @@ -11,7 +11,7 @@ import ( | ||||
| ) | ||||
|  | ||||
| func TestWrr(t *testing.T) { | ||||
| 	wrr := New(WithFilter(filter.Version("v2.0.0"))) | ||||
| 	wrr := New() | ||||
| 	var nodes []selector.Node | ||||
| 	nodes = append(nodes, selector.NewNode( | ||||
| 		"http", | ||||
| @@ -32,7 +32,7 @@ func TestWrr(t *testing.T) { | ||||
| 	wrr.Apply(nodes) | ||||
| 	var count1, count2 int | ||||
| 	for i := 0; i < 90; i++ { | ||||
| 		n, done, err := wrr.Select(context.Background()) | ||||
| 		n, done, err := wrr.Select(context.Background(), selector.WithNodeFilter(filter.Version("v2.0.0"))) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("expect no error, got %v", err) | ||||
| 		} | ||||
|   | ||||
| @@ -1,59 +1,45 @@ | ||||
| package grpc | ||||
|  | ||||
| import ( | ||||
| 	"sync" | ||||
|  | ||||
| 	"github.com/go-kratos/kratos/v2/registry" | ||||
| 	"github.com/go-kratos/kratos/v2/selector" | ||||
| 	"github.com/go-kratos/kratos/v2/selector/p2c" | ||||
| 	"github.com/go-kratos/kratos/v2/selector/random" | ||||
| 	"github.com/go-kratos/kratos/v2/selector/wrr" | ||||
| 	"github.com/go-kratos/kratos/v2/transport" | ||||
|  | ||||
| 	gBalancer "google.golang.org/grpc/balancer" | ||||
| 	"google.golang.org/grpc/balancer" | ||||
| 	"google.golang.org/grpc/balancer/base" | ||||
| 	"google.golang.org/grpc/metadata" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	_ base.PickerBuilder = &Builder{} | ||||
| 	_ gBalancer.Picker   = &Picker{} | ||||
| const ( | ||||
| 	balancerName = "selector" | ||||
| ) | ||||
|  | ||||
| 	mu sync.Mutex | ||||
| var ( | ||||
| 	_ base.PickerBuilder = &balancerBuilder{} | ||||
| 	_ balancer.Picker    = &balancerPicker{} | ||||
| ) | ||||
|  | ||||
| func init() { | ||||
| 	// inject global grpc balancer | ||||
| 	SetGlobalBalancer(random.Name, random.NewBuilder()) | ||||
| 	SetGlobalBalancer(wrr.Name, wrr.NewBuilder()) | ||||
| 	SetGlobalBalancer(p2c.Name, p2c.NewBuilder()) | ||||
| } | ||||
|  | ||||
| // SetGlobalBalancer set grpc balancer with scheme. | ||||
| func SetGlobalBalancer(scheme string, builder selector.Builder) { | ||||
| 	mu.Lock() | ||||
| 	defer mu.Unlock() | ||||
|  | ||||
| 	b := base.NewBalancerBuilder( | ||||
| 		scheme, | ||||
| 		&Builder{builder: builder}, | ||||
| 		balancerName, | ||||
| 		&balancerBuilder{ | ||||
| 			builder: selector.GlobalSelector(), | ||||
| 		}, | ||||
| 		base.Config{HealthCheck: true}, | ||||
| 	) | ||||
| 	gBalancer.Register(b) | ||||
| 	balancer.Register(b) | ||||
| } | ||||
|  | ||||
| // Builder is grpc balancer builder. | ||||
| type Builder struct { | ||||
| type balancerBuilder struct { | ||||
| 	builder selector.Builder | ||||
| } | ||||
|  | ||||
| // Build creates a grpc Picker. | ||||
| func (b *Builder) Build(info base.PickerBuildInfo) gBalancer.Picker { | ||||
| func (b *balancerBuilder) Build(info base.PickerBuildInfo) balancer.Picker { | ||||
| 	if len(info.ReadySCs) == 0 { | ||||
| 		// Block the RPC until a new picker is available via UpdateState(). | ||||
| 		return base.NewErrPicker(gBalancer.ErrNoSubConnAvailable) | ||||
| 		return base.NewErrPicker(balancer.ErrNoSubConnAvailable) | ||||
| 	} | ||||
|  | ||||
| 	nodes := make([]selector.Node, 0) | ||||
| 	for conn, info := range info.ReadySCs { | ||||
| 		ins, _ := info.Address.Attributes.Value("rawServiceInstance").(*registry.ServiceInstance) | ||||
| @@ -62,40 +48,40 @@ func (b *Builder) Build(info base.PickerBuildInfo) gBalancer.Picker { | ||||
| 			subConn: conn, | ||||
| 		}) | ||||
| 	} | ||||
| 	p := &Picker{ | ||||
| 	p := &balancerPicker{ | ||||
| 		selector: b.builder.Build(), | ||||
| 	} | ||||
| 	p.selector.Apply(nodes) | ||||
| 	return p | ||||
| } | ||||
|  | ||||
| // Picker is a grpc picker. | ||||
| type Picker struct { | ||||
| // balancerPicker is a grpc picker. | ||||
| type balancerPicker struct { | ||||
| 	selector selector.Selector | ||||
| } | ||||
|  | ||||
| // Pick pick instances. | ||||
| func (p *Picker) Pick(info gBalancer.PickInfo) (gBalancer.PickResult, error) { | ||||
| 	var filters []selector.Filter | ||||
| func (p *balancerPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { | ||||
| 	var filters []selector.NodeFilter | ||||
| 	if tr, ok := transport.FromClientContext(info.Ctx); ok { | ||||
| 		if gtr, ok := tr.(*Transport); ok { | ||||
| 			filters = gtr.SelectFilters() | ||||
| 			filters = gtr.NodeFilters() | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	n, done, err := p.selector.Select(info.Ctx, selector.WithFilter(filters...)) | ||||
| 	n, done, err := p.selector.Select(info.Ctx, selector.WithNodeFilter(filters...)) | ||||
| 	if err != nil { | ||||
| 		return gBalancer.PickResult{}, err | ||||
| 		return balancer.PickResult{}, err | ||||
| 	} | ||||
|  | ||||
| 	return gBalancer.PickResult{ | ||||
| 	return balancer.PickResult{ | ||||
| 		SubConn: n.(*grpcNode).subConn, | ||||
| 		Done: func(di gBalancer.DoneInfo) { | ||||
| 		Done: func(di balancer.DoneInfo) { | ||||
| 			done(info.Ctx, selector.DoneInfo{ | ||||
| 				Err:           di.Err, | ||||
| 				BytesSent:     di.BytesSent, | ||||
| 				BytesReceived: di.BytesReceived, | ||||
| 				ReplyMeta:     Trailer(di.Trailer), | ||||
| 				ReplyMD:       Trailer(di.Trailer), | ||||
| 			}) | ||||
| 		}, | ||||
| 	}, nil | ||||
| @@ -115,5 +101,5 @@ func (t Trailer) Get(k string) string { | ||||
|  | ||||
| type grpcNode struct { | ||||
| 	selector.Node | ||||
| 	subConn gBalancer.SubConn | ||||
| 	subConn balancer.SubConn | ||||
| } | ||||
|   | ||||
| @@ -19,19 +19,10 @@ func TestTrailer(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestBalancerName(t *testing.T) { | ||||
| 	o := &clientOptions{} | ||||
|  | ||||
| 	WithBalancerName("p2c")(o) | ||||
| 	if !reflect.DeepEqual("p2c", o.balancerName) { | ||||
| 		t.Errorf("expect %v, got %v", "p2c", o.balancerName) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestFilters(t *testing.T) { | ||||
| 	o := &clientOptions{} | ||||
|  | ||||
| 	WithFilter(func(_ context.Context, nodes []selector.Node) []selector.Node { | ||||
| 	WithNodeFilter(func(_ context.Context, nodes []selector.Node) []selector.Node { | ||||
| 		return nodes | ||||
| 	})(o) | ||||
| 	if !reflect.DeepEqual(1, len(o.filters)) { | ||||
|   | ||||
| @@ -10,7 +10,7 @@ import ( | ||||
| 	"github.com/go-kratos/kratos/v2/middleware" | ||||
| 	"github.com/go-kratos/kratos/v2/registry" | ||||
| 	"github.com/go-kratos/kratos/v2/selector" | ||||
| 	"github.com/go-kratos/kratos/v2/selector/wrr" | ||||
| 	"github.com/go-kratos/kratos/v2/selector/p2c" | ||||
| 	"github.com/go-kratos/kratos/v2/transport" | ||||
| 	"github.com/go-kratos/kratos/v2/transport/grpc/resolver/discovery" | ||||
|  | ||||
| @@ -23,6 +23,12 @@ import ( | ||||
| 	grpcmd "google.golang.org/grpc/metadata" | ||||
| ) | ||||
|  | ||||
| func init() { | ||||
| 	if selector.GlobalSelector() == nil { | ||||
| 		selector.SetGlobalSelector(p2c.NewBuilder()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ClientOption is gRPC client option. | ||||
| type ClientOption func(o *clientOptions) | ||||
|  | ||||
| @@ -75,15 +81,8 @@ func WithOptions(opts ...grpc.DialOption) ClientOption { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithBalancerName with balancer name | ||||
| func WithBalancerName(name string) ClientOption { | ||||
| 	return func(o *clientOptions) { | ||||
| 		o.balancerName = name | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithFilter with select filters | ||||
| func WithFilter(filters ...selector.Filter) ClientOption { | ||||
| // WithNodeFilter with select filters | ||||
| func WithNodeFilter(filters ...selector.NodeFilter) ClientOption { | ||||
| 	return func(o *clientOptions) { | ||||
| 		o.filters = filters | ||||
| 	} | ||||
| @@ -105,7 +104,7 @@ type clientOptions struct { | ||||
| 	ints         []grpc.UnaryClientInterceptor | ||||
| 	grpcOpts     []grpc.DialOption | ||||
| 	balancerName string | ||||
| 	filters      []selector.Filter | ||||
| 	filters      []selector.NodeFilter | ||||
| } | ||||
|  | ||||
| // Dial returns a GRPC connection. | ||||
| @@ -121,7 +120,7 @@ func DialInsecure(ctx context.Context, opts ...ClientOption) (*grpc.ClientConn, | ||||
| func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.ClientConn, error) { | ||||
| 	options := clientOptions{ | ||||
| 		timeout:      2000 * time.Millisecond, | ||||
| 		balancerName: wrr.Name, | ||||
| 		balancerName: balancerName, | ||||
| 	} | ||||
| 	for _, o := range opts { | ||||
| 		o(&options) | ||||
| @@ -156,13 +155,13 @@ func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.Clien | ||||
| 	return grpc.DialContext(ctx, options.endpoint, grpcOpts...) | ||||
| } | ||||
|  | ||||
| func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, filters []selector.Filter) grpc.UnaryClientInterceptor { | ||||
| func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, filters []selector.NodeFilter) grpc.UnaryClientInterceptor { | ||||
| 	return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { | ||||
| 		ctx = transport.NewClientContext(ctx, &Transport{ | ||||
| 			endpoint:  cc.Target(), | ||||
| 			operation: method, | ||||
| 			reqHeader: headerCarrier{}, | ||||
| 			filters:   filters, | ||||
| 			endpoint:    cc.Target(), | ||||
| 			operation:   method, | ||||
| 			reqHeader:   headerCarrier{}, | ||||
| 			nodeFilters: filters, | ||||
| 		}) | ||||
| 		if timeout > 0 { | ||||
| 			var cancel context.CancelFunc | ||||
|   | ||||
| @@ -14,7 +14,7 @@ type Transport struct { | ||||
| 	operation   string | ||||
| 	reqHeader   headerCarrier | ||||
| 	replyHeader headerCarrier | ||||
| 	filters     []selector.Filter | ||||
| 	nodeFilters []selector.NodeFilter | ||||
| } | ||||
|  | ||||
| // Kind returns the transport kind. | ||||
| @@ -42,9 +42,9 @@ func (tr *Transport) ReplyHeader() transport.Header { | ||||
| 	return tr.replyHeader | ||||
| } | ||||
|  | ||||
| // SelectFilters returns the client select filters. | ||||
| func (tr *Transport) SelectFilters() []selector.Filter { | ||||
| 	return tr.filters | ||||
| // NodeFilters returns the client select filters. | ||||
| func (tr *Transport) NodeFilters() []selector.NodeFilter { | ||||
| 	return tr.nodeFilters | ||||
| } | ||||
|  | ||||
| type headerCarrier metadata.MD | ||||
|   | ||||
| @@ -16,10 +16,16 @@ import ( | ||||
| 	"github.com/go-kratos/kratos/v2/middleware" | ||||
| 	"github.com/go-kratos/kratos/v2/registry" | ||||
| 	"github.com/go-kratos/kratos/v2/selector" | ||||
| 	"github.com/go-kratos/kratos/v2/selector/wrr" | ||||
| 	"github.com/go-kratos/kratos/v2/selector/p2c" | ||||
| 	"github.com/go-kratos/kratos/v2/transport" | ||||
| ) | ||||
|  | ||||
| func init() { | ||||
| 	if selector.GlobalSelector() == nil { | ||||
| 		selector.SetGlobalSelector(p2c.NewBuilder()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // DecodeErrorFunc is decode error func. | ||||
| type DecodeErrorFunc func(ctx context.Context, res *http.Response) error | ||||
|  | ||||
| @@ -43,7 +49,7 @@ type clientOptions struct { | ||||
| 	decoder      DecodeResponseFunc | ||||
| 	errorDecoder DecodeErrorFunc | ||||
| 	transport    http.RoundTripper | ||||
| 	selector     selector.Selector | ||||
| 	nodeFilters  []selector.NodeFilter | ||||
| 	discovery    registry.Discovery | ||||
| 	middleware   []middleware.Middleware | ||||
| 	block        bool | ||||
| @@ -112,10 +118,10 @@ func WithDiscovery(d registry.Discovery) ClientOption { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithSelector with client selector. | ||||
| func WithSelector(selector selector.Selector) ClientOption { | ||||
| // WithNodeFilter with select filters | ||||
| func WithNodeFilter(filters ...selector.NodeFilter) ClientOption { | ||||
| 	return func(o *clientOptions) { | ||||
| 		o.selector = selector | ||||
| 		o.nodeFilters = filters | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -140,6 +146,7 @@ type Client struct { | ||||
| 	r        *resolver | ||||
| 	cc       *http.Client | ||||
| 	insecure bool | ||||
| 	selector selector.Selector | ||||
| } | ||||
|  | ||||
| // NewClient returns an HTTP client. | ||||
| @@ -151,7 +158,6 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) { | ||||
| 		decoder:      DefaultResponseDecoder, | ||||
| 		errorDecoder: DefaultErrorDecoder, | ||||
| 		transport:    http.DefaultTransport, | ||||
| 		selector:     wrr.New(), | ||||
| 	} | ||||
| 	for _, o := range opts { | ||||
| 		o(&options) | ||||
| @@ -166,10 +172,11 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) { | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	selector := selector.GlobalSelector().Build() | ||||
| 	var r *resolver | ||||
| 	if options.discovery != nil { | ||||
| 		if target.Scheme == "discovery" { | ||||
| 			if r, err = newResolver(ctx, options.discovery, target, options.selector, options.block, insecure); err != nil { | ||||
| 			if r, err = newResolver(ctx, options.discovery, target, selector, options.block, insecure); err != nil { | ||||
| 				return nil, fmt.Errorf("[http client] new resolver failed!err: %v", options.endpoint) | ||||
| 			} | ||||
| 		} else if _, _, err := host.ExtractHostPort(options.endpoint); err != nil { | ||||
| @@ -185,6 +192,7 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) { | ||||
| 			Timeout:   options.timeout, | ||||
| 			Transport: options.transport, | ||||
| 		}, | ||||
| 		selector: selector, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| @@ -276,7 +284,7 @@ func (client *Client) do(req *http.Request) (*http.Response, error) { | ||||
| 			err  error | ||||
| 			node selector.Node | ||||
| 		) | ||||
| 		if node, done, err = client.opts.selector.Select(req.Context()); err != nil { | ||||
| 		if node, done, err = client.selector.Select(req.Context(), selector.WithNodeFilter(client.opts.nodeFilters...)); err != nil { | ||||
| 			return nil, errors.ServiceUnavailable("NODE_NOT_FOUND", err.Error()) | ||||
| 		} | ||||
| 		if client.insecure { | ||||
|   | ||||
| @@ -182,13 +182,18 @@ func TestWithDiscovery(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestWithSelector(t *testing.T) { | ||||
| 	ov := &selector.Default{} | ||||
| 	o := WithSelector(ov) | ||||
| func TestWithNodeFilter(t *testing.T) { | ||||
| 	ov := func(context.Context, []selector.Node) []selector.Node { | ||||
| 		return []selector.Node{&selector.DefaultNode{}} | ||||
| 	} | ||||
| 	o := WithNodeFilter(ov) | ||||
| 	co := &clientOptions{} | ||||
| 	o(co) | ||||
| 	if !reflect.DeepEqual(co.selector, ov) { | ||||
| 		t.Errorf("expected selector to be %v, got %v", ov, co.selector) | ||||
| 	for _, n := range co.nodeFilters { | ||||
| 		ret := n(context.Background(), nil) | ||||
| 		if len(ret) != 1 { | ||||
| 			t.Errorf("expected node  length to be 1, got %v", len(ret)) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user