diff --git a/internal/balancer/balancer.go b/transport/http/balancer/balancer.go similarity index 76% rename from internal/balancer/balancer.go rename to transport/http/balancer/balancer.go index 3c79632a1..db2529549 100644 --- a/internal/balancer/balancer.go +++ b/transport/http/balancer/balancer.go @@ -14,5 +14,5 @@ type DoneInfo struct { // Balancer is node pick balancer type Balancer interface { - Pick(ctx context.Context, pathPattern string, nodes []*registry.ServiceInstance) (node *registry.ServiceInstance, done func(DoneInfo), err error) + Pick(ctx context.Context, pathPattern string, nodes []*registry.ServiceInstance) (node *registry.ServiceInstance, done func(context.Context, DoneInfo), err error) } diff --git a/internal/balancer/random/random.go b/transport/http/balancer/random/random.go similarity index 65% rename from internal/balancer/random/random.go rename to transport/http/balancer/random/random.go index 293dbd66a..f56e66dfd 100644 --- a/internal/balancer/random/random.go +++ b/transport/http/balancer/random/random.go @@ -5,8 +5,8 @@ import ( "fmt" "math/rand" - "github.com/go-kratos/kratos/v2/internal/balancer" "github.com/go-kratos/kratos/v2/registry" + "github.com/go-kratos/kratos/v2/transport/http/balancer" ) var _ balancer.Balancer = &Balancer{} @@ -18,12 +18,12 @@ func New() *Balancer { return &Balancer{} } -func (b *Balancer) Pick(ctx context.Context, pathPattern string, nodes []*registry.ServiceInstance) (node *registry.ServiceInstance, done func(balancer.DoneInfo), err error) { +func (b *Balancer) Pick(ctx context.Context, pathPattern string, nodes []*registry.ServiceInstance) (node *registry.ServiceInstance, done func(context.Context, balancer.DoneInfo), err error) { if len(nodes) == 0 { return nil, nil, fmt.Errorf("no instances avaiable") } else if len(nodes) == 1 { - return nodes[0], func(di balancer.DoneInfo) {}, nil + return nodes[0], func(context.Context, balancer.DoneInfo) {}, nil } idx := rand.Intn(len(nodes)) - return nodes[idx], func(di balancer.DoneInfo) {}, nil + return nodes[idx], func(context.Context, balancer.DoneInfo) {}, nil } diff --git a/transport/http/client.go b/transport/http/client.go index 7cce5eab0..21bf5c852 100644 --- a/transport/http/client.go +++ b/transport/http/client.go @@ -12,12 +12,12 @@ import ( "github.com/go-kratos/kratos/v2/encoding" "github.com/go-kratos/kratos/v2/errors" - "github.com/go-kratos/kratos/v2/internal/balancer" - "github.com/go-kratos/kratos/v2/internal/balancer/random" "github.com/go-kratos/kratos/v2/internal/httputil" "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/transport" + "github.com/go-kratos/kratos/v2/transport/http/balancer" + "github.com/go-kratos/kratos/v2/transport/http/balancer/random" ) // Client is http client @@ -255,7 +255,7 @@ func (client *Client) Invoke(ctx context.Context, path string, args interface{}, func (client *Client) invoke(ctx context.Context, req *http.Request, args interface{}, reply interface{}, c callInfo) error { h := func(ctx context.Context, in interface{}) (interface{}, error) { - var done func(balancer.DoneInfo) + var done func(context.Context, balancer.DoneInfo) if client.r != nil { nodes := client.r.fetch(ctx) if len(nodes) == 0 { @@ -276,7 +276,7 @@ func (client *Client) invoke(ctx context.Context, req *http.Request, args interf } res, err := client.do(ctx, req, c) if done != nil { - done(balancer.DoneInfo{Err: err}) + done(ctx, balancer.DoneInfo{Err: err}) } if err != nil { return nil, err