mirror of
https://github.com/go-kratos/kratos.git
synced 2025-03-17 21:07:54 +02:00
refactor: unify selector filter (#2277)
* unify selector Co-authored-by: caoguoliang01 <caoguoliang01@bilibili.com> Co-authored-by: chenzhihui <zhihui_chen@foxmail.com>
This commit is contained in:
parent
d11c6892b4
commit
11cd43e3c3
@ -9,7 +9,6 @@ import (
|
||||
type Default struct {
|
||||
NodeBuilder WeightedNodeBuilder
|
||||
Balancer Balancer
|
||||
Filters []Filter
|
||||
|
||||
nodes atomic.Value
|
||||
}
|
||||
@ -27,16 +26,13 @@ func (d *Default) Select(ctx context.Context, opts ...SelectOption) (selected No
|
||||
for _, o := range opts {
|
||||
o(&options)
|
||||
}
|
||||
if len(d.Filters) > 0 || len(options.Filters) > 0 {
|
||||
if len(options.NodeFilters) > 0 {
|
||||
newNodes := make([]Node, len(nodes))
|
||||
for i, wc := range nodes {
|
||||
newNodes[i] = wc
|
||||
}
|
||||
for _, f := range d.Filters {
|
||||
newNodes = f(ctx, newNodes)
|
||||
}
|
||||
for _, f := range options.Filters {
|
||||
newNodes = f(ctx, newNodes)
|
||||
for _, filter := range options.NodeFilters {
|
||||
newNodes = filter(ctx, newNodes)
|
||||
}
|
||||
candidates = make([]WeightedNode, len(newNodes))
|
||||
for i, n := range newNodes {
|
||||
@ -74,7 +70,6 @@ func (d *Default) Apply(nodes []Node) {
|
||||
type DefaultBuilder struct {
|
||||
Node WeightedNodeBuilder
|
||||
Balancer BalancerBuilder
|
||||
Filters []Filter
|
||||
}
|
||||
|
||||
// Build create builder
|
||||
@ -82,6 +77,5 @@ func (db *DefaultBuilder) Build() Selector {
|
||||
return &Default{
|
||||
NodeBuilder: db.Node,
|
||||
Balancer: db.Balancer.Build(),
|
||||
Filters: db.Filters,
|
||||
}
|
||||
}
|
||||
|
@ -2,5 +2,5 @@ package selector
|
||||
|
||||
import "context"
|
||||
|
||||
// Filter is select filter.
|
||||
type Filter func(context.Context, []Node) []Node
|
||||
// NodeFilter is select filter.
|
||||
type NodeFilter func(context.Context, []Node) []Node
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
// Version is version filter.
|
||||
func Version(version string) selector.Filter {
|
||||
func Version(version string) selector.NodeFilter {
|
||||
return func(_ context.Context, nodes []selector.Node) []selector.Node {
|
||||
newNodes := nodes[:0]
|
||||
for _, n := range nodes {
|
||||
|
13
selector/global.go
Normal file
13
selector/global.go
Normal file
@ -0,0 +1,13 @@
|
||||
package selector
|
||||
|
||||
var globalSelector Builder
|
||||
|
||||
// GlobalSelector returns global selector builder.
|
||||
func GlobalSelector() Builder {
|
||||
return globalSelector
|
||||
}
|
||||
|
||||
// SetGlobalSelector set global selector builder.
|
||||
func SetGlobalSelector(builder Builder) {
|
||||
globalSelector = builder
|
||||
}
|
@ -2,15 +2,15 @@ package selector
|
||||
|
||||
// SelectOptions is Select Options.
|
||||
type SelectOptions struct {
|
||||
Filters []Filter
|
||||
NodeFilters []NodeFilter
|
||||
}
|
||||
|
||||
// SelectOption is Selector option.
|
||||
type SelectOption func(*SelectOptions)
|
||||
|
||||
// WithFilter with filter options
|
||||
func WithFilter(fn ...Filter) SelectOption {
|
||||
// WithNodeFilter with filter options
|
||||
func WithNodeFilter(fn ...NodeFilter) SelectOption {
|
||||
return func(opts *SelectOptions) {
|
||||
opts.Filters = fn
|
||||
opts.NodeFilters = fn
|
||||
}
|
||||
}
|
||||
|
@ -19,20 +19,11 @@ const (
|
||||
|
||||
var _ selector.Balancer = &Balancer{}
|
||||
|
||||
// WithFilter with select filters
|
||||
func WithFilter(filters ...selector.Filter) Option {
|
||||
return func(o *options) {
|
||||
o.filters = filters
|
||||
}
|
||||
}
|
||||
|
||||
// Option is random builder option.
|
||||
type Option func(o *options)
|
||||
|
||||
// options is random builder options
|
||||
type options struct {
|
||||
filters []selector.Filter
|
||||
}
|
||||
type options struct{}
|
||||
|
||||
// New creates a p2c selector.
|
||||
func New(opts ...Option) selector.Selector {
|
||||
@ -95,7 +86,6 @@ func NewBuilder(opts ...Option) selector.Builder {
|
||||
opt(&option)
|
||||
}
|
||||
return &selector.DefaultBuilder{
|
||||
Filters: option.filters,
|
||||
Balancer: &Builder{},
|
||||
Node: &ewma.Builder{},
|
||||
}
|
||||
|
@ -16,7 +16,7 @@ import (
|
||||
)
|
||||
|
||||
func TestWrr3(t *testing.T) {
|
||||
p2c := New(WithFilter(filter.Version("v2.0.0")))
|
||||
p2c := New()
|
||||
var nodes []selector.Node
|
||||
for i := 0; i < 3; i++ {
|
||||
addr := fmt.Sprintf("127.0.0.%d:8080", i)
|
||||
@ -41,7 +41,7 @@ func TestWrr3(t *testing.T) {
|
||||
d := time.Duration(rand.Intn(500)) * time.Millisecond
|
||||
lk.Unlock()
|
||||
time.Sleep(d)
|
||||
n, done, err := p2c.Select(context.Background())
|
||||
n, done, err := p2c.Select(context.Background(), selector.WithNodeFilter(filter.Version("v2.0.0")))
|
||||
if err != nil {
|
||||
t.Errorf("expect %v, got %v", nil, err)
|
||||
}
|
||||
@ -92,7 +92,7 @@ func TestEmpty(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestOne(t *testing.T) {
|
||||
p2c := New(WithFilter(filter.Version("v2.0.0")))
|
||||
p2c := New()
|
||||
var nodes []selector.Node
|
||||
for i := 0; i < 1; i++ {
|
||||
addr := fmt.Sprintf("127.0.0.%d:8080", i)
|
||||
@ -106,7 +106,7 @@ func TestOne(t *testing.T) {
|
||||
}))
|
||||
}
|
||||
p2c.Apply(nodes)
|
||||
n, done, err := p2c.Select(context.Background())
|
||||
n, done, err := p2c.Select(context.Background(), selector.WithNodeFilter(filter.Version("v2.0.0")))
|
||||
if err != nil {
|
||||
t.Errorf("expect %v, got %v", nil, err)
|
||||
}
|
||||
|
@ -15,20 +15,11 @@ const (
|
||||
|
||||
var _ selector.Balancer = &Balancer{} // Name is balancer name
|
||||
|
||||
// WithFilter with select filters
|
||||
func WithFilter(filters ...selector.Filter) Option {
|
||||
return func(o *options) {
|
||||
o.filters = filters
|
||||
}
|
||||
}
|
||||
|
||||
// Option is random builder option.
|
||||
type Option func(o *options)
|
||||
|
||||
// options is random builder options
|
||||
type options struct {
|
||||
filters []selector.Filter
|
||||
}
|
||||
type options struct{}
|
||||
|
||||
// Balancer is a random balancer.
|
||||
type Balancer struct{}
|
||||
@ -56,7 +47,6 @@ func NewBuilder(opts ...Option) selector.Builder {
|
||||
opt(&option)
|
||||
}
|
||||
return &selector.DefaultBuilder{
|
||||
Filters: option.filters,
|
||||
Balancer: &Builder{},
|
||||
Node: &direct.Builder{},
|
||||
}
|
||||
|
@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
func TestWrr(t *testing.T) {
|
||||
random := New(WithFilter(filter.Version("v2.0.0")))
|
||||
random := New()
|
||||
var nodes []selector.Node
|
||||
nodes = append(nodes, selector.NewNode(
|
||||
"http",
|
||||
@ -31,7 +31,7 @@ func TestWrr(t *testing.T) {
|
||||
random.Apply(nodes)
|
||||
var count1, count2 int
|
||||
for i := 0; i < 200; i++ {
|
||||
n, done, err := random.Select(context.Background())
|
||||
n, done, err := random.Select(context.Background(), selector.WithNodeFilter(filter.Version("v2.0.0")))
|
||||
if err != nil {
|
||||
t.Errorf("expect no error, got %v", err)
|
||||
}
|
||||
|
@ -57,7 +57,7 @@ type DoneInfo struct {
|
||||
// Response Error
|
||||
Err error
|
||||
// Response Metadata
|
||||
ReplyMeta ReplyMeta
|
||||
ReplyMD ReplyMD
|
||||
|
||||
// BytesSent indicates if any bytes have been sent to the server.
|
||||
BytesSent bool
|
||||
@ -65,8 +65,8 @@ type DoneInfo struct {
|
||||
BytesReceived bool
|
||||
}
|
||||
|
||||
// ReplyMeta is Reply Metadata.
|
||||
type ReplyMeta interface {
|
||||
// ReplyMD is Reply Metadata.
|
||||
type ReplyMD interface {
|
||||
Get(key string) string
|
||||
}
|
||||
|
||||
|
@ -49,7 +49,7 @@ func (b *mockWeightedNodeBuilder) Build(n Node) WeightedNode {
|
||||
return &mockWeightedNode{Node: n}
|
||||
}
|
||||
|
||||
func mockFilter(version string) Filter {
|
||||
func mockFilter(version string) NodeFilter {
|
||||
return func(_ context.Context, nodes []Node) []Node {
|
||||
newNodes := nodes[:0]
|
||||
for _, n := range nodes {
|
||||
@ -83,7 +83,6 @@ func (b *mockBalancer) Pick(ctx context.Context, nodes []WeightedNode) (selected
|
||||
func TestDefault(t *testing.T) {
|
||||
builder := DefaultBuilder{
|
||||
Node: &mockWeightedNodeBuilder{},
|
||||
Filters: []Filter{mockFilter("v2.0.0")},
|
||||
Balancer: &mockBalancerBuilder{},
|
||||
}
|
||||
selector := builder.Build()
|
||||
@ -109,7 +108,7 @@ func TestDefault(t *testing.T) {
|
||||
Metadata: map[string]string{"weight": "10"},
|
||||
}))
|
||||
selector.Apply(nodes)
|
||||
n, done, err := selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0")))
|
||||
n, done, err := selector.Select(context.Background(), WithNodeFilter(mockFilter("v2.0.0")))
|
||||
if err != nil {
|
||||
t.Errorf("expect %v, got %v", nil, err)
|
||||
}
|
||||
@ -137,7 +136,7 @@ func TestDefault(t *testing.T) {
|
||||
done(context.Background(), DoneInfo{})
|
||||
|
||||
// no v3.0.0 instance
|
||||
n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v3.0.0")))
|
||||
n, done, err = selector.Select(context.Background(), WithNodeFilter(mockFilter("v3.0.0")))
|
||||
if !errors.Is(ErrNoAvailable, err) {
|
||||
t.Errorf("expect %v, got %v", ErrNoAvailable, err)
|
||||
}
|
||||
@ -150,7 +149,7 @@ func TestDefault(t *testing.T) {
|
||||
|
||||
// apply zero instance
|
||||
selector.Apply([]Node{})
|
||||
n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0")))
|
||||
n, done, err = selector.Select(context.Background(), WithNodeFilter(mockFilter("v2.0.0")))
|
||||
if !errors.Is(ErrNoAvailable, err) {
|
||||
t.Errorf("expect %v, got %v", ErrNoAvailable, err)
|
||||
}
|
||||
@ -163,7 +162,7 @@ func TestDefault(t *testing.T) {
|
||||
|
||||
// apply zero instance
|
||||
selector.Apply(nil)
|
||||
n, done, err = selector.Select(context.Background(), WithFilter(mockFilter("v2.0.0")))
|
||||
n, done, err = selector.Select(context.Background(), WithNodeFilter(mockFilter("v2.0.0")))
|
||||
if !errors.Is(ErrNoAvailable, err) {
|
||||
t.Errorf("expect %v, got %v", ErrNoAvailable, err)
|
||||
}
|
||||
|
@ -15,20 +15,11 @@ const (
|
||||
|
||||
var _ selector.Balancer = &Balancer{} // Name is balancer name
|
||||
|
||||
// WithFilter with select filters
|
||||
func WithFilter(filters ...selector.Filter) Option {
|
||||
return func(o *options) {
|
||||
o.filters = filters
|
||||
}
|
||||
}
|
||||
|
||||
// Option is random builder option.
|
||||
type Option func(o *options)
|
||||
|
||||
// options is random builder options
|
||||
type options struct {
|
||||
filters []selector.Filter
|
||||
}
|
||||
type options struct{}
|
||||
|
||||
// Balancer is a random balancer.
|
||||
type Balancer struct {
|
||||
@ -77,7 +68,6 @@ func NewBuilder(opts ...Option) selector.Builder {
|
||||
opt(&option)
|
||||
}
|
||||
return &selector.DefaultBuilder{
|
||||
Filters: option.filters,
|
||||
Balancer: &Builder{},
|
||||
Node: &direct.Builder{},
|
||||
}
|
||||
|
@ -11,7 +11,7 @@ import (
|
||||
)
|
||||
|
||||
func TestWrr(t *testing.T) {
|
||||
wrr := New(WithFilter(filter.Version("v2.0.0")))
|
||||
wrr := New()
|
||||
var nodes []selector.Node
|
||||
nodes = append(nodes, selector.NewNode(
|
||||
"http",
|
||||
@ -32,7 +32,7 @@ func TestWrr(t *testing.T) {
|
||||
wrr.Apply(nodes)
|
||||
var count1, count2 int
|
||||
for i := 0; i < 90; i++ {
|
||||
n, done, err := wrr.Select(context.Background())
|
||||
n, done, err := wrr.Select(context.Background(), selector.WithNodeFilter(filter.Version("v2.0.0")))
|
||||
if err != nil {
|
||||
t.Errorf("expect no error, got %v", err)
|
||||
}
|
||||
|
@ -1,59 +1,45 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/registry"
|
||||
"github.com/go-kratos/kratos/v2/selector"
|
||||
"github.com/go-kratos/kratos/v2/selector/p2c"
|
||||
"github.com/go-kratos/kratos/v2/selector/random"
|
||||
"github.com/go-kratos/kratos/v2/selector/wrr"
|
||||
"github.com/go-kratos/kratos/v2/transport"
|
||||
|
||||
gBalancer "google.golang.org/grpc/balancer"
|
||||
"google.golang.org/grpc/balancer"
|
||||
"google.golang.org/grpc/balancer/base"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
var (
|
||||
_ base.PickerBuilder = &Builder{}
|
||||
_ gBalancer.Picker = &Picker{}
|
||||
const (
|
||||
balancerName = "selector"
|
||||
)
|
||||
|
||||
mu sync.Mutex
|
||||
var (
|
||||
_ base.PickerBuilder = &balancerBuilder{}
|
||||
_ balancer.Picker = &balancerPicker{}
|
||||
)
|
||||
|
||||
func init() {
|
||||
// inject global grpc balancer
|
||||
SetGlobalBalancer(random.Name, random.NewBuilder())
|
||||
SetGlobalBalancer(wrr.Name, wrr.NewBuilder())
|
||||
SetGlobalBalancer(p2c.Name, p2c.NewBuilder())
|
||||
}
|
||||
|
||||
// SetGlobalBalancer set grpc balancer with scheme.
|
||||
func SetGlobalBalancer(scheme string, builder selector.Builder) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
b := base.NewBalancerBuilder(
|
||||
scheme,
|
||||
&Builder{builder: builder},
|
||||
balancerName,
|
||||
&balancerBuilder{
|
||||
builder: selector.GlobalSelector(),
|
||||
},
|
||||
base.Config{HealthCheck: true},
|
||||
)
|
||||
gBalancer.Register(b)
|
||||
balancer.Register(b)
|
||||
}
|
||||
|
||||
// Builder is grpc balancer builder.
|
||||
type Builder struct {
|
||||
type balancerBuilder struct {
|
||||
builder selector.Builder
|
||||
}
|
||||
|
||||
// Build creates a grpc Picker.
|
||||
func (b *Builder) Build(info base.PickerBuildInfo) gBalancer.Picker {
|
||||
func (b *balancerBuilder) Build(info base.PickerBuildInfo) balancer.Picker {
|
||||
if len(info.ReadySCs) == 0 {
|
||||
// Block the RPC until a new picker is available via UpdateState().
|
||||
return base.NewErrPicker(gBalancer.ErrNoSubConnAvailable)
|
||||
return base.NewErrPicker(balancer.ErrNoSubConnAvailable)
|
||||
}
|
||||
|
||||
nodes := make([]selector.Node, 0)
|
||||
for conn, info := range info.ReadySCs {
|
||||
ins, _ := info.Address.Attributes.Value("rawServiceInstance").(*registry.ServiceInstance)
|
||||
@ -62,40 +48,40 @@ func (b *Builder) Build(info base.PickerBuildInfo) gBalancer.Picker {
|
||||
subConn: conn,
|
||||
})
|
||||
}
|
||||
p := &Picker{
|
||||
p := &balancerPicker{
|
||||
selector: b.builder.Build(),
|
||||
}
|
||||
p.selector.Apply(nodes)
|
||||
return p
|
||||
}
|
||||
|
||||
// Picker is a grpc picker.
|
||||
type Picker struct {
|
||||
// balancerPicker is a grpc picker.
|
||||
type balancerPicker struct {
|
||||
selector selector.Selector
|
||||
}
|
||||
|
||||
// Pick pick instances.
|
||||
func (p *Picker) Pick(info gBalancer.PickInfo) (gBalancer.PickResult, error) {
|
||||
var filters []selector.Filter
|
||||
func (p *balancerPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
|
||||
var filters []selector.NodeFilter
|
||||
if tr, ok := transport.FromClientContext(info.Ctx); ok {
|
||||
if gtr, ok := tr.(*Transport); ok {
|
||||
filters = gtr.SelectFilters()
|
||||
filters = gtr.NodeFilters()
|
||||
}
|
||||
}
|
||||
|
||||
n, done, err := p.selector.Select(info.Ctx, selector.WithFilter(filters...))
|
||||
n, done, err := p.selector.Select(info.Ctx, selector.WithNodeFilter(filters...))
|
||||
if err != nil {
|
||||
return gBalancer.PickResult{}, err
|
||||
return balancer.PickResult{}, err
|
||||
}
|
||||
|
||||
return gBalancer.PickResult{
|
||||
return balancer.PickResult{
|
||||
SubConn: n.(*grpcNode).subConn,
|
||||
Done: func(di gBalancer.DoneInfo) {
|
||||
Done: func(di balancer.DoneInfo) {
|
||||
done(info.Ctx, selector.DoneInfo{
|
||||
Err: di.Err,
|
||||
BytesSent: di.BytesSent,
|
||||
BytesReceived: di.BytesReceived,
|
||||
ReplyMeta: Trailer(di.Trailer),
|
||||
ReplyMD: Trailer(di.Trailer),
|
||||
})
|
||||
},
|
||||
}, nil
|
||||
@ -115,5 +101,5 @@ func (t Trailer) Get(k string) string {
|
||||
|
||||
type grpcNode struct {
|
||||
selector.Node
|
||||
subConn gBalancer.SubConn
|
||||
subConn balancer.SubConn
|
||||
}
|
||||
|
@ -19,19 +19,10 @@ func TestTrailer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBalancerName(t *testing.T) {
|
||||
o := &clientOptions{}
|
||||
|
||||
WithBalancerName("p2c")(o)
|
||||
if !reflect.DeepEqual("p2c", o.balancerName) {
|
||||
t.Errorf("expect %v, got %v", "p2c", o.balancerName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilters(t *testing.T) {
|
||||
o := &clientOptions{}
|
||||
|
||||
WithFilter(func(_ context.Context, nodes []selector.Node) []selector.Node {
|
||||
WithNodeFilter(func(_ context.Context, nodes []selector.Node) []selector.Node {
|
||||
return nodes
|
||||
})(o)
|
||||
if !reflect.DeepEqual(1, len(o.filters)) {
|
||||
|
@ -10,7 +10,7 @@ import (
|
||||
"github.com/go-kratos/kratos/v2/middleware"
|
||||
"github.com/go-kratos/kratos/v2/registry"
|
||||
"github.com/go-kratos/kratos/v2/selector"
|
||||
"github.com/go-kratos/kratos/v2/selector/wrr"
|
||||
"github.com/go-kratos/kratos/v2/selector/p2c"
|
||||
"github.com/go-kratos/kratos/v2/transport"
|
||||
"github.com/go-kratos/kratos/v2/transport/grpc/resolver/discovery"
|
||||
|
||||
@ -23,6 +23,12 @@ import (
|
||||
grpcmd "google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if selector.GlobalSelector() == nil {
|
||||
selector.SetGlobalSelector(p2c.NewBuilder())
|
||||
}
|
||||
}
|
||||
|
||||
// ClientOption is gRPC client option.
|
||||
type ClientOption func(o *clientOptions)
|
||||
|
||||
@ -75,15 +81,8 @@ func WithOptions(opts ...grpc.DialOption) ClientOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithBalancerName with balancer name
|
||||
func WithBalancerName(name string) ClientOption {
|
||||
return func(o *clientOptions) {
|
||||
o.balancerName = name
|
||||
}
|
||||
}
|
||||
|
||||
// WithFilter with select filters
|
||||
func WithFilter(filters ...selector.Filter) ClientOption {
|
||||
// WithNodeFilter with select filters
|
||||
func WithNodeFilter(filters ...selector.NodeFilter) ClientOption {
|
||||
return func(o *clientOptions) {
|
||||
o.filters = filters
|
||||
}
|
||||
@ -105,7 +104,7 @@ type clientOptions struct {
|
||||
ints []grpc.UnaryClientInterceptor
|
||||
grpcOpts []grpc.DialOption
|
||||
balancerName string
|
||||
filters []selector.Filter
|
||||
filters []selector.NodeFilter
|
||||
}
|
||||
|
||||
// Dial returns a GRPC connection.
|
||||
@ -121,7 +120,7 @@ func DialInsecure(ctx context.Context, opts ...ClientOption) (*grpc.ClientConn,
|
||||
func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.ClientConn, error) {
|
||||
options := clientOptions{
|
||||
timeout: 2000 * time.Millisecond,
|
||||
balancerName: wrr.Name,
|
||||
balancerName: balancerName,
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(&options)
|
||||
@ -156,13 +155,13 @@ func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.Clien
|
||||
return grpc.DialContext(ctx, options.endpoint, grpcOpts...)
|
||||
}
|
||||
|
||||
func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, filters []selector.Filter) grpc.UnaryClientInterceptor {
|
||||
func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, filters []selector.NodeFilter) grpc.UnaryClientInterceptor {
|
||||
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
ctx = transport.NewClientContext(ctx, &Transport{
|
||||
endpoint: cc.Target(),
|
||||
operation: method,
|
||||
reqHeader: headerCarrier{},
|
||||
filters: filters,
|
||||
endpoint: cc.Target(),
|
||||
operation: method,
|
||||
reqHeader: headerCarrier{},
|
||||
nodeFilters: filters,
|
||||
})
|
||||
if timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
|
@ -14,7 +14,7 @@ type Transport struct {
|
||||
operation string
|
||||
reqHeader headerCarrier
|
||||
replyHeader headerCarrier
|
||||
filters []selector.Filter
|
||||
nodeFilters []selector.NodeFilter
|
||||
}
|
||||
|
||||
// Kind returns the transport kind.
|
||||
@ -42,9 +42,9 @@ func (tr *Transport) ReplyHeader() transport.Header {
|
||||
return tr.replyHeader
|
||||
}
|
||||
|
||||
// SelectFilters returns the client select filters.
|
||||
func (tr *Transport) SelectFilters() []selector.Filter {
|
||||
return tr.filters
|
||||
// NodeFilters returns the client select filters.
|
||||
func (tr *Transport) NodeFilters() []selector.NodeFilter {
|
||||
return tr.nodeFilters
|
||||
}
|
||||
|
||||
type headerCarrier metadata.MD
|
||||
|
@ -16,10 +16,16 @@ import (
|
||||
"github.com/go-kratos/kratos/v2/middleware"
|
||||
"github.com/go-kratos/kratos/v2/registry"
|
||||
"github.com/go-kratos/kratos/v2/selector"
|
||||
"github.com/go-kratos/kratos/v2/selector/wrr"
|
||||
"github.com/go-kratos/kratos/v2/selector/p2c"
|
||||
"github.com/go-kratos/kratos/v2/transport"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if selector.GlobalSelector() == nil {
|
||||
selector.SetGlobalSelector(p2c.NewBuilder())
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeErrorFunc is decode error func.
|
||||
type DecodeErrorFunc func(ctx context.Context, res *http.Response) error
|
||||
|
||||
@ -43,7 +49,7 @@ type clientOptions struct {
|
||||
decoder DecodeResponseFunc
|
||||
errorDecoder DecodeErrorFunc
|
||||
transport http.RoundTripper
|
||||
selector selector.Selector
|
||||
nodeFilters []selector.NodeFilter
|
||||
discovery registry.Discovery
|
||||
middleware []middleware.Middleware
|
||||
block bool
|
||||
@ -112,10 +118,10 @@ func WithDiscovery(d registry.Discovery) ClientOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithSelector with client selector.
|
||||
func WithSelector(selector selector.Selector) ClientOption {
|
||||
// WithNodeFilter with select filters
|
||||
func WithNodeFilter(filters ...selector.NodeFilter) ClientOption {
|
||||
return func(o *clientOptions) {
|
||||
o.selector = selector
|
||||
o.nodeFilters = filters
|
||||
}
|
||||
}
|
||||
|
||||
@ -140,6 +146,7 @@ type Client struct {
|
||||
r *resolver
|
||||
cc *http.Client
|
||||
insecure bool
|
||||
selector selector.Selector
|
||||
}
|
||||
|
||||
// NewClient returns an HTTP client.
|
||||
@ -151,7 +158,6 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) {
|
||||
decoder: DefaultResponseDecoder,
|
||||
errorDecoder: DefaultErrorDecoder,
|
||||
transport: http.DefaultTransport,
|
||||
selector: wrr.New(),
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(&options)
|
||||
@ -166,10 +172,11 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selector := selector.GlobalSelector().Build()
|
||||
var r *resolver
|
||||
if options.discovery != nil {
|
||||
if target.Scheme == "discovery" {
|
||||
if r, err = newResolver(ctx, options.discovery, target, options.selector, options.block, insecure); err != nil {
|
||||
if r, err = newResolver(ctx, options.discovery, target, selector, options.block, insecure); err != nil {
|
||||
return nil, fmt.Errorf("[http client] new resolver failed!err: %v", options.endpoint)
|
||||
}
|
||||
} else if _, _, err := host.ExtractHostPort(options.endpoint); err != nil {
|
||||
@ -185,6 +192,7 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) {
|
||||
Timeout: options.timeout,
|
||||
Transport: options.transport,
|
||||
},
|
||||
selector: selector,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -276,7 +284,7 @@ func (client *Client) do(req *http.Request) (*http.Response, error) {
|
||||
err error
|
||||
node selector.Node
|
||||
)
|
||||
if node, done, err = client.opts.selector.Select(req.Context()); err != nil {
|
||||
if node, done, err = client.selector.Select(req.Context(), selector.WithNodeFilter(client.opts.nodeFilters...)); err != nil {
|
||||
return nil, errors.ServiceUnavailable("NODE_NOT_FOUND", err.Error())
|
||||
}
|
||||
if client.insecure {
|
||||
|
@ -182,13 +182,18 @@ func TestWithDiscovery(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithSelector(t *testing.T) {
|
||||
ov := &selector.Default{}
|
||||
o := WithSelector(ov)
|
||||
func TestWithNodeFilter(t *testing.T) {
|
||||
ov := func(context.Context, []selector.Node) []selector.Node {
|
||||
return []selector.Node{&selector.DefaultNode{}}
|
||||
}
|
||||
o := WithNodeFilter(ov)
|
||||
co := &clientOptions{}
|
||||
o(co)
|
||||
if !reflect.DeepEqual(co.selector, ov) {
|
||||
t.Errorf("expected selector to be %v, got %v", ov, co.selector)
|
||||
for _, n := range co.nodeFilters {
|
||||
ret := n(context.Background(), nil)
|
||||
if len(ret) != 1 {
|
||||
t.Errorf("expected node length to be 1, got %v", len(ret))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user