diff --git a/examples/helloworld/client/main.go b/examples/helloworld/client/main.go index 77c1950eb..ba7b4fbdb 100644 --- a/examples/helloworld/client/main.go +++ b/examples/helloworld/client/main.go @@ -24,7 +24,7 @@ func callHTTP() { recovery.Recovery(), ), transhttp.WithEndpoint("127.0.0.1:8000"), - transhttp.WithSchema("http"), + transhttp.WithScheme("http"), ) if err != nil { log.Fatal(err) diff --git a/examples/registry/consul/client/main.go b/examples/registry/consul/client/main.go index df9f70cda..0f18399d4 100644 --- a/examples/registry/consul/client/main.go +++ b/examples/registry/consul/client/main.go @@ -3,10 +3,13 @@ package main import ( "context" "log" + "time" "github.com/go-kratos/consul/registry" "github.com/go-kratos/kratos/examples/helloworld/helloworld" + "github.com/go-kratos/kratos/v2/middleware/recovery" "github.com/go-kratos/kratos/v2/transport/grpc" + transhttp "github.com/go-kratos/kratos/v2/transport/http" "github.com/hashicorp/consul/api" ) @@ -15,6 +18,11 @@ func main() { if err != nil { panic(err) } + callHTTP(cli) + callGRPC(cli) +} + +func callGRPC(cli *api.Client) { r := registry.New(cli) conn, err := grpc.DialInsecure( context.Background(), @@ -25,9 +33,33 @@ func main() { log.Fatal(err) } client := helloworld.NewGreeterClient(conn) - reply, err := client.SayHello(context.Background(), &helloworld.HelloRequest{Name: "kratos"}) + reply, err := client.SayHello(context.Background(), &helloworld.HelloRequest{Name: "kratos_grpc"}) if err != nil { log.Fatal(err) } log.Printf("[grpc] SayHello %+v\n", reply) } + +func callHTTP(cli *api.Client) { + r := registry.New(cli) + conn, err := transhttp.NewClient( + context.Background(), + transhttp.WithMiddleware( + recovery.Recovery(), + ), + transhttp.WithScheme("http"), + transhttp.WithEndpoint("discovery:///helloworld"), + transhttp.WithDiscovery(r), + ) + if err != nil { + log.Fatal(err) + } + time.Sleep(time.Millisecond * 250) + client := helloworld.NewGreeterHttpClient(conn) + reply, err := client.SayHello(context.Background(), &helloworld.HelloRequest{Name: "kratos_http"}) + if err != nil { + log.Fatal(err) + } + log.Printf("[http] SayHello %s\n", reply.Message) + +} diff --git a/examples/registry/consul/server/main.go b/examples/registry/consul/server/main.go index cadb31d0a..133f3b512 100644 --- a/examples/registry/consul/server/main.go +++ b/examples/registry/consul/server/main.go @@ -3,12 +3,16 @@ package main import ( "context" "fmt" - "log" + "os" "github.com/go-kratos/consul/registry" pb "github.com/go-kratos/kratos/examples/helloworld/helloworld" "github.com/go-kratos/kratos/v2" + "github.com/go-kratos/kratos/v2/log" + "github.com/go-kratos/kratos/v2/middleware/logging" + "github.com/go-kratos/kratos/v2/middleware/recovery" "github.com/go-kratos/kratos/v2/transport/grpc" + "github.com/go-kratos/kratos/v2/transport/http" "github.com/hashicorp/consul/api" ) @@ -23,27 +27,41 @@ func (s *server) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloRe } func main() { - grpcSrv := grpc.NewServer( - grpc.Address(":9000"), - ) - - s := &server{} - pb.RegisterGreeterServer(grpcSrv, s) - - cli, err := api.NewClient(api.DefaultConfig()) + logger := log.NewStdLogger(os.Stdout) + log := log.NewHelper(logger) + consulClient, err := api.NewClient(api.DefaultConfig()) if err != nil { panic(err) } - r := registry.New(cli) + + grpcSrv := grpc.NewServer( + grpc.Address(":9000"), + grpc.Middleware( + recovery.Recovery(), + logging.Server(logger), + ), + ) + s := &server{} + pb.RegisterGreeterServer(grpcSrv, s) + + httpSrv := http.NewServer(http.Address(":8000")) + httpSrv.HandlePrefix("/", pb.NewGreeterHandler(s, + http.Middleware( + recovery.Recovery(), + )), + ) + + r := registry.New(consulClient) app := kratos.New( kratos.Name("helloworld"), kratos.Server( grpcSrv, + httpSrv, ), kratos.Registrar(r), ) if err := app.Run(); err != nil { - log.Fatal(err) + log.Errorf("app run failed:%v", err) } } diff --git a/internal/balancer/balancer.go b/internal/balancer/balancer.go new file mode 100644 index 000000000..3c79632a1 --- /dev/null +++ b/internal/balancer/balancer.go @@ -0,0 +1,18 @@ +package balancer + +import ( + "context" + + "github.com/go-kratos/kratos/v2/registry" +) + +// DoneInfo is callback when rpc done +type DoneInfo struct { + Err error + Trailer map[string]string +} + +// Balancer is node pick balancer +type Balancer interface { + Pick(ctx context.Context, pathPattern string, nodes []*registry.ServiceInstance) (node *registry.ServiceInstance, done func(DoneInfo), err error) +} diff --git a/internal/balancer/random/random.go b/internal/balancer/random/random.go new file mode 100644 index 000000000..293dbd66a --- /dev/null +++ b/internal/balancer/random/random.go @@ -0,0 +1,29 @@ +package random + +import ( + "context" + "fmt" + "math/rand" + + "github.com/go-kratos/kratos/v2/internal/balancer" + "github.com/go-kratos/kratos/v2/registry" +) + +var _ balancer.Balancer = &Balancer{} + +type Balancer struct { +} + +func New() *Balancer { + return &Balancer{} +} + +func (b *Balancer) Pick(ctx context.Context, pathPattern string, nodes []*registry.ServiceInstance) (node *registry.ServiceInstance, done func(balancer.DoneInfo), err error) { + if len(nodes) == 0 { + return nil, nil, fmt.Errorf("no instances avaiable") + } else if len(nodes) == 1 { + return nodes[0], func(di balancer.DoneInfo) {}, nil + } + idx := rand.Intn(len(nodes)) + return nodes[idx], func(di balancer.DoneInfo) {}, nil +} diff --git a/transport/http/client.go b/transport/http/client.go index 3ef20e58c..7cce5eab0 100644 --- a/transport/http/client.go +++ b/transport/http/client.go @@ -7,28 +7,40 @@ import ( "io" "io/ioutil" "net/http" + "net/url" "time" "github.com/go-kratos/kratos/v2/encoding" "github.com/go-kratos/kratos/v2/errors" + "github.com/go-kratos/kratos/v2/internal/balancer" + "github.com/go-kratos/kratos/v2/internal/balancer/random" "github.com/go-kratos/kratos/v2/internal/httputil" "github.com/go-kratos/kratos/v2/middleware" + "github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/transport" ) // Client is http client type Client struct { cc *http.Client + r *resolver + b balancer.Balancer - schema string - endpoint string + scheme string + target Target userAgent string middleware middleware.Middleware encoder EncodeRequestFunc decoder DecodeResponseFunc errorDecoder DecodeErrorFunc + discovery registry.Discovery } +const ( + // errNodeNotFound represents service node not found. + errNodeNotFound = "NODE_NOT_FOUND" +) + // DecodeErrorFunc is decode error func. type DecodeErrorFunc func(ctx context.Context, res *http.Response) error @@ -69,10 +81,10 @@ func WithMiddleware(m ...middleware.Middleware) ClientOption { } } -// WithSchema with client schema. -func WithSchema(schema string) ClientOption { +// WithScheme with client schema. +func WithScheme(scheme string) ClientOption { return func(o *clientOptions) { - o.schema = schema + o.scheme = scheme } } @@ -104,43 +116,94 @@ func WithErrorDecoder(errorDecoder DecodeErrorFunc) ClientOption { } } +// WithDiscovery with client discovery. +func WithDiscovery(d registry.Discovery) ClientOption { + return func(o *clientOptions) { + o.discovery = d + } +} + +// WithBalancer with client balancer. +// Experimental +// Notice: This type is EXPERIMENTAL and may be changed or removed in a later release. +func WithBalancer(b balancer.Balancer) ClientOption { + return func(o *clientOptions) { + o.balancer = b + } +} + // Client is a HTTP transport client. type clientOptions struct { ctx context.Context transport http.RoundTripper middleware middleware.Middleware timeout time.Duration - schema string + scheme string endpoint string userAgent string encoder EncodeRequestFunc decoder DecodeResponseFunc errorDecoder DecodeErrorFunc + discovery registry.Discovery + balancer balancer.Balancer } // NewClient returns an HTTP client. func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) { options := &clientOptions{ ctx: ctx, - schema: "http", + scheme: "http", timeout: 1 * time.Second, encoder: defaultRequestEncoder, decoder: defaultResponseDecoder, errorDecoder: defaultErrorDecoder, transport: http.DefaultTransport, + discovery: nil, + balancer: random.New(), } for _, o := range opts { o(options) } + target := Target{ + Scheme: options.scheme, + Endpoint: options.endpoint, + } + var r *resolver + if options.endpoint != "" && options.discovery != nil { + u, err := url.Parse(options.endpoint) + if err != nil { + u, err = url.Parse("http://" + options.endpoint) + if err != nil { + return nil, fmt.Errorf("[http client] invalid endpoint format: %v", options.endpoint) + } + } + if u.Scheme == "discovery" && len(u.Path) > 1 { + target = Target{ + Scheme: u.Scheme, + Authority: u.Host, + Endpoint: u.Path[1:], + } + r, err = newResolver(ctx, options.scheme, options.discovery, target) + if err != nil { + return nil, fmt.Errorf("[http client] new resolver failed!err: %v", options.endpoint) + } + } else { + return nil, fmt.Errorf("[http client] invalid endpoint format: %v", options.endpoint) + } + } + return &Client{ cc: &http.Client{Timeout: options.timeout, Transport: options.transport}, + r: r, encoder: options.encoder, decoder: options.decoder, errorDecoder: options.errorDecoder, middleware: options.middleware, userAgent: options.userAgent, - endpoint: options.endpoint, - schema: options.schema, + target: target, + scheme: options.scheme, + discovery: options.discovery, + b: options.balancer, }, nil } @@ -169,7 +232,7 @@ func (client *Client) Invoke(ctx context.Context, path string, args interface{}, } reqBody = bytes.NewReader(body) } - url := fmt.Sprintf("%s://%s%s", client.schema, client.endpoint, path) + url := fmt.Sprintf("%s://%s%s", client.scheme, client.target.Endpoint, path) req, err := http.NewRequest(c.method, url, reqBody) if err != nil { return err @@ -192,7 +255,29 @@ func (client *Client) Invoke(ctx context.Context, path string, args interface{}, func (client *Client) invoke(ctx context.Context, req *http.Request, args interface{}, reply interface{}, c callInfo) error { h := func(ctx context.Context, in interface{}) (interface{}, error) { + var done func(balancer.DoneInfo) + if client.r != nil { + nodes := client.r.fetch(ctx) + if len(nodes) == 0 { + return nil, errors.ServiceUnavailable(errNodeNotFound, "fetch error") + } + var node *registry.ServiceInstance + var err error + node, done, err = client.b.Pick(ctx, c.pathPattern, nodes) + if err != nil { + return nil, errors.ServiceUnavailable(errNodeNotFound, err.Error()) + } + req = req.Clone(ctx) + addr, err := parseEndpoint(client.scheme, node.Endpoints) + if err != nil { + return nil, errors.ServiceUnavailable(errNodeNotFound, err.Error()) + } + req.URL.Host = addr + } res, err := client.do(ctx, req, c) + if done != nil { + done(balancer.DoneInfo{Err: err}) + } if err != nil { return nil, err } diff --git a/transport/http/resovler.go b/transport/http/resovler.go new file mode 100644 index 000000000..2f605da49 --- /dev/null +++ b/transport/http/resovler.go @@ -0,0 +1,86 @@ +package http + +import ( + "context" + "net/url" + "sync" + + "github.com/go-kratos/kratos/v2/log" + "github.com/go-kratos/kratos/v2/registry" +) + +// Target is resolver target +type Target struct { + Scheme string + Authority string + Endpoint string +} + +type resolver struct { + lock sync.RWMutex + nodes []*registry.ServiceInstance + + target Target + watcher registry.Watcher + logger *log.Helper +} + +func newResolver(ctx context.Context, scheme string, discovery registry.Discovery, target Target) (*resolver, error) { + watcher, err := discovery.Watch(ctx, target.Endpoint) + if err != nil { + return nil, err + } + r := &resolver{ + target: target, + watcher: watcher, + logger: log.NewHelper(log.DefaultLogger), + } + go func() { + for { + services, err := watcher.Next() + if err != nil { + r.logger.Errorf("http client watch services got unexpected error:=%v", err) + return + } + var nodes []*registry.ServiceInstance + for _, in := range services { + endpoint, err := parseEndpoint(scheme, in.Endpoints) + if err != nil { + r.logger.Errorf("Failed to parse discovery endpoint: %v error %v", in.Endpoints, err) + continue + } + if endpoint == "" { + continue + } + nodes = append(nodes, in) + } + if len(nodes) != 0 { + r.lock.Lock() + r.nodes = nodes + r.lock.Unlock() + } + } + }() + return r, nil +} + +func (r *resolver) fetch(ctx context.Context) []*registry.ServiceInstance { + r.lock.RLock() + nodes := r.nodes + r.lock.RUnlock() + + return nodes +} + +func parseEndpoint(schema string, endpoints []string) (string, error) { + for _, e := range endpoints { + u, err := url.Parse(e) + if err != nil { + return "", err + } + if u.Scheme == schema { + return u.Host, nil + } + } + return "", nil +}