diff --git a/endpoint/endpoint.go b/endpoint/endpoint.go index c31625b..2702ef2 100644 --- a/endpoint/endpoint.go +++ b/endpoint/endpoint.go @@ -10,6 +10,10 @@ import ( // It represents a single RPC method. type Endpoint func(ctx context.Context, request interface{}) (response interface{}, err error) +// Nop is an endpoint that does nothing and returns a nil error. +// Useful for tests. +func Nop(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } + // Middleware is a chainable behavior modifier for endpoints. type Middleware func(Endpoint) Endpoint diff --git a/sd/cache/benchmark_test.go b/sd/cache/benchmark_test.go new file mode 100644 index 0000000..41f1821 --- /dev/null +++ b/sd/cache/benchmark_test.go @@ -0,0 +1,29 @@ +package cache + +import ( + "io" + "testing" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" +) + +func BenchmarkEndpoints(b *testing.B) { + var ( + ca = make(closer) + cb = make(closer) + cmap = map[string]io.Closer{"a": ca, "b": cb} + factory = func(instance string) (endpoint.Endpoint, io.Closer, error) { return endpoint.Nop, cmap[instance], nil } + c = New(factory, log.NewNopLogger()) + ) + + b.ReportAllocs() + + c.Update([]string{"a", "b"}) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + c.Endpoints() + } + }) +} diff --git a/sd/cache/cache.go b/sd/cache/cache.go new file mode 100644 index 0000000..ab4ab7d --- /dev/null +++ b/sd/cache/cache.go @@ -0,0 +1,97 @@ +package cache + +import ( + "io" + "sort" + "sync" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/sd" +) + +// Cache collects the most recent set of endpoints from a service discovery +// system via a subscriber, and makes them available to consumers. Cache is +// meant to be embedded inside of a concrete subscriber, and can serve Service +// invocations directly. +type Cache struct { + mtx sync.RWMutex + factory sd.Factory + cache map[string]endpointCloser + slice []endpoint.Endpoint + logger log.Logger +} + +type endpointCloser struct { + endpoint.Endpoint + io.Closer +} + +// New returns a new, empty endpoint cache. +func New(factory sd.Factory, logger log.Logger) *Cache { + return &Cache{ + factory: factory, + cache: map[string]endpointCloser{}, + logger: logger, + } +} + +// Update should be invoked by clients with a complete set of current instance +// strings whenever that set changes. The cache manufactures new endpoints via +// the factory, closes old endpoints when they disappear, and persists existing +// endpoints if they survive through an update. +func (c *Cache) Update(instances []string) { + c.mtx.Lock() + defer c.mtx.Unlock() + + // Deterministic order (for later). + sort.Strings(instances) + + // Produce the current set of services. + cache := make(map[string]endpointCloser, len(instances)) + for _, instance := range instances { + // If it already exists, just copy it over. + if sc, ok := c.cache[instance]; ok { + cache[instance] = sc + delete(c.cache, instance) + continue + } + + // If it doesn't exist, create it. + service, closer, err := c.factory(instance) + if err != nil { + c.logger.Log("instance", instance, "err", err) + continue + } + cache[instance] = endpointCloser{service, closer} + } + + // Close any leftover endpoints. + for _, sc := range c.cache { + if sc.Closer != nil { + sc.Closer.Close() + } + } + + // Populate the slice of endpoints. 
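+	// Iterating the sorted instances, rather than the cache map, yields a
+	// slice ordered lexicographically by instance string.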
+ slice := make([]endpoint.Endpoint, 0, len(cache)) + for _, instance := range instances { + // A bad factory may mean an instance is not present. + if _, ok := cache[instance]; !ok { + continue + } + slice = append(slice, cache[instance].Endpoint) + } + + // Swap and trigger GC for old copies. + c.slice = slice + c.cache = cache +} + +// Endpoints yields the current set of (presumably identical) endpoints, ordered +// lexicographically by the corresponding instance string. +func (c *Cache) Endpoints() []endpoint.Endpoint { + c.mtx.RLock() + defer c.mtx.RUnlock() + return c.slice +} diff --git a/sd/cache/cache_test.go b/sd/cache/cache_test.go new file mode 100644 index 0000000..be9abaf --- /dev/null +++ b/sd/cache/cache_test.go @@ -0,0 +1,91 @@ +package cache + +import ( + "errors" + "io" + "testing" + "time" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" +) + +func TestCache(t *testing.T) { + var ( + ca = make(closer) + cb = make(closer) + c = map[string]io.Closer{"a": ca, "b": cb} + f = func(instance string) (endpoint.Endpoint, io.Closer, error) { return endpoint.Nop, c[instance], nil } + cache = New(f, log.NewNopLogger()) + ) + + // Populate + cache.Update([]string{"a", "b"}) + select { + case <-ca: + t.Errorf("endpoint a closed, not good") + case <-cb: + t.Errorf("endpoint b closed, not good") + case <-time.After(time.Millisecond): + t.Logf("no closures yet, good") + } + if want, have := 2, len(cache.Endpoints()); want != have { + t.Errorf("want %d, have %d", want, have) + } + + // Duplicate, should be no-op + cache.Update([]string{"a", "b"}) + select { + case <-ca: + t.Errorf("endpoint a closed, not good") + case <-cb: + t.Errorf("endpoint b closed, not good") + case <-time.After(time.Millisecond): + t.Logf("no closures yet, good") + } + if want, have := 2, len(cache.Endpoints()); want != have { + t.Errorf("want %d, have %d", want, have) + } + + // Delete b + go cache.Update([]string{"a"}) + select { + case <-ca: + t.Errorf("endpoint a closed, not good") + case <-cb: + t.Logf("endpoint b closed, good") + case <-time.After(time.Second): + t.Errorf("didn't close the deleted instance in time") + } + if want, have := 1, len(cache.Endpoints()); want != have { + t.Errorf("want %d, have %d", want, have) + } + + // Delete a + go cache.Update([]string{}) + select { + // case <-cb: will succeed, as it's closed + case <-ca: + t.Logf("endpoint a closed, good") + case <-time.After(time.Second): + t.Errorf("didn't close the deleted instance in time") + } + if want, have := 0, len(cache.Endpoints()); want != have { + t.Errorf("want %d, have %d", want, have) + } +} + +func TestBadFactory(t *testing.T) { + cache := New(func(string) (endpoint.Endpoint, io.Closer, error) { + return nil, nil, errors.New("bad factory") + }, log.NewNopLogger()) + + cache.Update([]string{"foo:1234", "bar:5678"}) + if want, have := 0, len(cache.Endpoints()); want != have { + t.Errorf("want %d, have %d", want, have) + } +} + +type closer chan struct{} + +func (c closer) Close() error { close(c); return nil } diff --git a/sd/consul/client.go b/sd/consul/client.go new file mode 100644 index 0000000..4d88ce3 --- /dev/null +++ b/sd/consul/client.go @@ -0,0 +1,37 @@ +package consul + +import consul "github.com/hashicorp/consul/api" + +// Client is a wrapper around the Consul API. +type Client interface { + // Register a service with the local agent. + Register(r *consul.AgentServiceRegistration) error + + // Deregister a service with the local agent. 
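+	// Deregistration is keyed on the registration's ID field.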
+	Deregister(r *consul.AgentServiceRegistration) error
+
+	// Service lists the instances of the given service in the catalog,
+	// optionally filtered by tag, restricted to healthy instances when
+	// passingOnly is true.
+	Service(service, tag string, passingOnly bool, queryOpts *consul.QueryOptions) ([]*consul.ServiceEntry, *consul.QueryMeta, error)
+}
+
+type client struct {
+	consul *consul.Client
+}
+
+// NewClient returns an implementation of the Client interface, wrapping a
+// concrete Consul client.
+func NewClient(c *consul.Client) Client {
+	return &client{consul: c}
+}
+
+func (c *client) Register(r *consul.AgentServiceRegistration) error {
+	return c.consul.Agent().ServiceRegister(r)
+}
+
+func (c *client) Deregister(r *consul.AgentServiceRegistration) error {
+	return c.consul.Agent().ServiceDeregister(r.ID)
+}
+
+func (c *client) Service(service, tag string, passingOnly bool, queryOpts *consul.QueryOptions) ([]*consul.ServiceEntry, *consul.QueryMeta, error) {
+	return c.consul.Health().Service(service, tag, passingOnly, queryOpts)
+}
diff --git a/sd/consul/client_test.go b/sd/consul/client_test.go
new file mode 100644
index 0000000..cf02aea
--- /dev/null
+++ b/sd/consul/client_test.go
@@ -0,0 +1,156 @@
+package consul
+
+import (
+	"errors"
+	"io"
+	"reflect"
+	"testing"
+
+	stdconsul "github.com/hashicorp/consul/api"
+	"golang.org/x/net/context"
+
+	"github.com/go-kit/kit/endpoint"
+)
+
+func TestClientRegistration(t *testing.T) {
+	c := newTestClient(nil)
+
+	services, _, err := c.Service(testRegistration.Name, "", true, &stdconsul.QueryOptions{})
+	if err != nil {
+		t.Error(err)
+	}
+	if want, have := 0, len(services); want != have {
+		t.Errorf("want %d, have %d", want, have)
+	}
+
+	if err := c.Register(testRegistration); err != nil {
+		t.Error(err)
+	}
+
+	if err := c.Register(testRegistration); err == nil {
+		t.Errorf("want error, have %v", err)
+	}
+
+	services, _, err = c.Service(testRegistration.Name, "", true, &stdconsul.QueryOptions{})
+	if err != nil {
+		t.Error(err)
+	}
+	if want, have := 1, len(services); want != have {
+		t.Errorf("want %d, have %d", want, have)
+	}
+
+	if err := c.Deregister(testRegistration); err != nil {
+		t.Error(err)
+	}
+
+	if err := c.Deregister(testRegistration); err == nil {
+		t.Errorf("want error, have %v", err)
+	}
+
+	services, _, err = c.Service(testRegistration.Name, "", true, &stdconsul.QueryOptions{})
+	if err != nil {
+		t.Error(err)
+	}
+	if want, have := 0, len(services); want != have {
+		t.Errorf("want %d, have %d", want, have)
+	}
+}
+
+type testClient struct {
+	entries []*stdconsul.ServiceEntry
+}
+
+func newTestClient(entries []*stdconsul.ServiceEntry) *testClient {
+	return &testClient{
+		entries: entries,
+	}
+}
+
+var _ Client = &testClient{}
+
+func (c *testClient) Service(service, tag string, _ bool, opts *stdconsul.QueryOptions) ([]*stdconsul.ServiceEntry, *stdconsul.QueryMeta, error) {
+	var results []*stdconsul.ServiceEntry
+
+	for _, entry := range c.entries {
+		if entry.Service.Service != service {
+			continue
+		}
+		if tag != "" {
+			tagMap := map[string]struct{}{}
+
+			for _, t := range entry.Service.Tags {
+				tagMap[t] = struct{}{}
+			}
+
+			if _, ok := tagMap[tag]; !ok {
+				continue
+			}
+		}
+
+		results = append(results, entry)
+	}
+
+	return results, &stdconsul.QueryMeta{}, nil
+}
+
+func (c *testClient) Register(r *stdconsul.AgentServiceRegistration) error {
+	toAdd := registration2entry(r)
+
+	for _, entry := range c.entries {
+		if reflect.DeepEqual(*entry, *toAdd) {
+			return errors.New("duplicate")
+		}
+	}
+
+	c.entries = append(c.entries, toAdd)
+	return nil
+}
+
+func (c *testClient) Deregister(r *stdconsul.AgentServiceRegistration) error {
+	toDelete :=
registration2entry(r) + + var newEntries []*stdconsul.ServiceEntry + for _, entry := range c.entries { + if reflect.DeepEqual(*entry, *toDelete) { + continue + } + newEntries = append(newEntries, entry) + } + if len(newEntries) == len(c.entries) { + return errors.New("not found") + } + + c.entries = newEntries + return nil +} + +func registration2entry(r *stdconsul.AgentServiceRegistration) *stdconsul.ServiceEntry { + return &stdconsul.ServiceEntry{ + Node: &stdconsul.Node{ + Node: "some-node", + Address: r.Address, + }, + Service: &stdconsul.AgentService{ + ID: r.ID, + Service: r.Name, + Tags: r.Tags, + Port: r.Port, + Address: r.Address, + }, + // Checks ignored + } +} + +func testFactory(instance string) (endpoint.Endpoint, io.Closer, error) { + return func(context.Context, interface{}) (interface{}, error) { + return instance, nil + }, nil, nil +} + +var testRegistration = &stdconsul.AgentServiceRegistration{ + ID: "my-id", + Name: "my-name", + Tags: []string{"my-tag-1", "my-tag-2"}, + Port: 12345, + Address: "my-address", +} diff --git a/sd/consul/integration_test.go b/sd/consul/integration_test.go new file mode 100644 index 0000000..495adad --- /dev/null +++ b/sd/consul/integration_test.go @@ -0,0 +1,86 @@ +// +build integration + +package consul + +import ( + "io" + "os" + "testing" + "time" + + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/service" + stdconsul "github.com/hashicorp/consul/api" +) + +func TestIntegration(t *testing.T) { + // Connect to Consul. + // docker run -p 8500:8500 progrium/consul -server -bootstrap + consulAddr := os.Getenv("CONSUL_ADDRESS") + if consulAddr == "" { + t.Fatal("CONSUL_ADDRESS is not set") + } + stdClient, err := stdconsul.NewClient(&stdconsul.Config{ + Address: consulAddr, + }) + if err != nil { + t.Fatal(err) + } + client := NewClient(stdClient) + logger := log.NewLogfmtLogger(os.Stderr) + + // Produce a fake service registration. + r := &stdconsul.AgentServiceRegistration{ + ID: "my-service-ID", + Name: "my-service-name", + Tags: []string{"alpha", "beta"}, + Port: 12345, + Address: "my-address", + EnableTagOverride: false, + // skipping check(s) + } + + // Build a subscriber on r.Name + r.Tags. + factory := func(instance string) (service.Service, io.Closer, error) { + t.Logf("factory invoked for %q", instance) + return service.Fixed{}, nil, nil + } + subscriber, err := NewSubscriber( + client, + factory, + log.NewContext(logger).With("component", "subscriber"), + r.Name, + r.Tags, + true, + ) + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Second) + + // Before we publish, we should have no services. + services, err := subscriber.Services() + if err != nil { + t.Error(err) + } + if want, have := 0, len(services); want != have { + t.Errorf("want %d, have %d", want, have) + } + + // Build a registrar for r. + registrar := NewRegistrar(client, r, log.NewContext(logger).With("component", "registrar")) + registrar.Register() + defer registrar.Deregister() + + time.Sleep(time.Second) + + // Now we should have one active service. 
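+	// The sleep above gives the agent time to reflect the registration.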
+ services, err = subscriber.Services() + if err != nil { + t.Error(err) + } + if want, have := 1, len(services); want != have { + t.Errorf("want %d, have %d", want, have) + } +} diff --git a/sd/consul/registrar.go b/sd/consul/registrar.go new file mode 100644 index 0000000..e89fef6 --- /dev/null +++ b/sd/consul/registrar.go @@ -0,0 +1,44 @@ +package consul + +import ( + "fmt" + + stdconsul "github.com/hashicorp/consul/api" + + "github.com/go-kit/kit/log" +) + +// Registrar registers service instance liveness information to Consul. +type Registrar struct { + client Client + registration *stdconsul.AgentServiceRegistration + logger log.Logger +} + +// NewRegistrar returns a Consul Registrar acting on the provided catalog +// registration. +func NewRegistrar(client Client, r *stdconsul.AgentServiceRegistration, logger log.Logger) *Registrar { + return &Registrar{ + client: client, + registration: r, + logger: log.NewContext(logger).With("service", r.Name, "tags", fmt.Sprint(r.Tags), "address", r.Address), + } +} + +// Register implements sd.Registrar interface. +func (p *Registrar) Register() { + if err := p.client.Register(p.registration); err != nil { + p.logger.Log("err", err) + } else { + p.logger.Log("action", "register") + } +} + +// Deregister implements sd.Registrar interface. +func (p *Registrar) Deregister() { + if err := p.client.Deregister(p.registration); err != nil { + p.logger.Log("err", err) + } else { + p.logger.Log("action", "deregister") + } +} diff --git a/sd/consul/registrar_test.go b/sd/consul/registrar_test.go new file mode 100644 index 0000000..edc7723 --- /dev/null +++ b/sd/consul/registrar_test.go @@ -0,0 +1,27 @@ +package consul + +import ( + "testing" + + stdconsul "github.com/hashicorp/consul/api" + + "github.com/go-kit/kit/log" +) + +func TestRegistrar(t *testing.T) { + client := newTestClient([]*stdconsul.ServiceEntry{}) + p := NewRegistrar(client, testRegistration, log.NewNopLogger()) + if want, have := 0, len(client.entries); want != have { + t.Errorf("want %d, have %d", want, have) + } + + p.Register() + if want, have := 1, len(client.entries); want != have { + t.Errorf("want %d, have %d", want, have) + } + + p.Deregister() + if want, have := 0, len(client.entries); want != have { + t.Errorf("want %d, have %d", want, have) + } +} diff --git a/sd/consul/subscriber.go b/sd/consul/subscriber.go new file mode 100644 index 0000000..a2840dd --- /dev/null +++ b/sd/consul/subscriber.go @@ -0,0 +1,166 @@ +package consul + +import ( + "fmt" + "io" + + consul "github.com/hashicorp/consul/api" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/sd" + "github.com/go-kit/kit/sd/cache" +) + +const defaultIndex = 0 + +// Subscriber yields endpoints for a service in Consul. Updates to the service +// are watched and will update the Subscriber endpoints. +type Subscriber struct { + cache *cache.Cache + client Client + logger log.Logger + service string + tags []string + passingOnly bool + endpointsc chan []endpoint.Endpoint + quitc chan struct{} +} + +var _ sd.Subscriber = &Subscriber{} + +// NewSubscriber returns a Consul subscriber which returns endpoints for the +// requested service. It only returns instances for which all of the passed tags +// are present. 
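+//
+// Consul's Health API accepts only a single tag per query, so the first tag is
+// passed to Consul and any remaining tags are filtered client-side.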
+func NewSubscriber(client Client, factory sd.Factory, logger log.Logger, service string, tags []string, passingOnly bool) (*Subscriber, error) { + s := &Subscriber{ + cache: cache.New(factory, logger), + client: client, + logger: log.NewContext(logger).With("service", service, "tags", fmt.Sprint(tags)), + service: service, + tags: tags, + passingOnly: passingOnly, + quitc: make(chan struct{}), + } + + instances, index, err := s.getInstances(defaultIndex, nil) + if err == nil { + s.logger.Log("instances", len(instances)) + } else { + s.logger.Log("err", err) + } + + s.cache.Update(instances) + go s.loop(index) + return s, nil +} + +// Endpoints implements the Subscriber interface. +func (s *Subscriber) Endpoints() ([]endpoint.Endpoint, error) { + return s.cache.Endpoints(), nil +} + +// Stop terminates the subscriber. +func (s *Subscriber) Stop() { + close(s.quitc) +} + +func (s *Subscriber) loop(lastIndex uint64) { + var ( + instances []string + err error + ) + for { + instances, lastIndex, err = s.getInstances(lastIndex, s.quitc) + switch { + case err == io.EOF: + return // stopped via quitc + case err != nil: + s.logger.Log("err", err) + default: + s.cache.Update(instances) + } + } +} + +func (s *Subscriber) getInstances(lastIndex uint64, interruptc chan struct{}) ([]string, uint64, error) { + tag := "" + if len(s.tags) > 0 { + tag = s.tags[0] + } + + // Consul doesn't support more than one tag in its service query method. + // https://github.com/hashicorp/consul/issues/294 + // Hashi suggest prepared queries, but they don't support blocking. + // https://www.consul.io/docs/agent/http/query.html#execute + // If we want blocking for efficiency, we must filter tags manually. + + type response struct { + instances []string + index uint64 + } + + var ( + errc = make(chan error, 1) + resc = make(chan response, 1) + ) + + go func() { + entries, meta, err := s.client.Service(s.service, tag, s.passingOnly, &consul.QueryOptions{ + WaitIndex: lastIndex, + }) + if err != nil { + errc <- err + return + } + if len(s.tags) > 1 { + entries = filterEntries(entries, s.tags[1:]...) 
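+			// The first tag was already matched by the Consul query itself.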
+ } + resc <- response{ + instances: makeInstances(entries), + index: meta.LastIndex, + } + }() + + select { + case err := <-errc: + return nil, 0, err + case res := <-resc: + return res.instances, res.index, nil + case <-interruptc: + return nil, 0, io.EOF + } +} + +func filterEntries(entries []*consul.ServiceEntry, tags ...string) []*consul.ServiceEntry { + var es []*consul.ServiceEntry + +ENTRIES: + for _, entry := range entries { + ts := make(map[string]struct{}, len(entry.Service.Tags)) + for _, tag := range entry.Service.Tags { + ts[tag] = struct{}{} + } + + for _, tag := range tags { + if _, ok := ts[tag]; !ok { + continue ENTRIES + } + } + es = append(es, entry) + } + + return es +} + +func makeInstances(entries []*consul.ServiceEntry) []string { + instances := make([]string, len(entries)) + for i, entry := range entries { + addr := entry.Node.Address + if entry.Service.Address != "" { + addr = entry.Service.Address + } + instances[i] = fmt.Sprintf("%s:%d", addr, entry.Service.Port) + } + return instances +} diff --git a/sd/consul/subscriber_test.go b/sd/consul/subscriber_test.go new file mode 100644 index 0000000..9be92bb --- /dev/null +++ b/sd/consul/subscriber_test.go @@ -0,0 +1,150 @@ +package consul + +import ( + "testing" + + consul "github.com/hashicorp/consul/api" + "golang.org/x/net/context" + + "github.com/go-kit/kit/log" +) + +var consulState = []*consul.ServiceEntry{ + { + Node: &consul.Node{ + Address: "10.0.0.0", + Node: "app00.local", + }, + Service: &consul.AgentService{ + ID: "search-api-0", + Port: 8000, + Service: "search", + Tags: []string{ + "api", + "v1", + }, + }, + }, + { + Node: &consul.Node{ + Address: "10.0.0.1", + Node: "app01.local", + }, + Service: &consul.AgentService{ + ID: "search-api-1", + Port: 8001, + Service: "search", + Tags: []string{ + "api", + "v2", + }, + }, + }, + { + Node: &consul.Node{ + Address: "10.0.0.1", + Node: "app01.local", + }, + Service: &consul.AgentService{ + Address: "10.0.0.10", + ID: "search-db-0", + Port: 9000, + Service: "search", + Tags: []string{ + "db", + }, + }, + }, +} + +func TestSubscriber(t *testing.T) { + var ( + logger = log.NewNopLogger() + client = newTestClient(consulState) + ) + + s, err := NewSubscriber(client, testFactory, logger, "search", []string{"api"}, true) + if err != nil { + t.Fatal(err) + } + defer s.Stop() + + endpoints, err := s.Endpoints() + if err != nil { + t.Fatal(err) + } + + if want, have := 2, len(endpoints); want != have { + t.Errorf("want %d, have %d", want, have) + } +} + +func TestSubscriberNoService(t *testing.T) { + var ( + logger = log.NewNopLogger() + client = newTestClient(consulState) + ) + + s, err := NewSubscriber(client, testFactory, logger, "feed", []string{}, true) + if err != nil { + t.Fatal(err) + } + defer s.Stop() + + endpoints, err := s.Endpoints() + if err != nil { + t.Fatal(err) + } + + if want, have := 0, len(endpoints); want != have { + t.Fatalf("want %d, have %d", want, have) + } +} + +func TestSubscriberWithTags(t *testing.T) { + var ( + logger = log.NewNopLogger() + client = newTestClient(consulState) + ) + + s, err := NewSubscriber(client, testFactory, logger, "search", []string{"api", "v2"}, true) + if err != nil { + t.Fatal(err) + } + defer s.Stop() + + endpoints, err := s.Endpoints() + if err != nil { + t.Fatal(err) + } + + if want, have := 1, len(endpoints); want != have { + t.Fatalf("want %d, have %d", want, have) + } +} + +func TestSubscriberAddressOverride(t *testing.T) { + s, err := NewSubscriber(newTestClient(consulState), testFactory, 
log.NewNopLogger(), "search", []string{"db"}, true) + if err != nil { + t.Fatal(err) + } + defer s.Stop() + + endpoints, err := s.Endpoints() + if err != nil { + t.Fatal(err) + } + + if want, have := 1, len(endpoints); want != have { + t.Fatalf("want %d, have %d", want, have) + } + + response, err := endpoints[0](context.Background(), struct{}{}) + if err != nil { + t.Fatal(err) + } + + if want, have := "10.0.0.10:9000", response.(string); want != have { + t.Errorf("want %q, have %q", want, have) + } +} diff --git a/sd/dnssrv/lookup.go b/sd/dnssrv/lookup.go new file mode 100644 index 0000000..9d46ea6 --- /dev/null +++ b/sd/dnssrv/lookup.go @@ -0,0 +1,7 @@ +package dnssrv + +import "net" + +// Lookup is a function that resolves a DNS SRV record to multiple addresses. +// It has the same signature as net.LookupSRV. +type Lookup func(service, proto, name string) (cname string, addrs []*net.SRV, err error) diff --git a/sd/dnssrv/subscriber.go b/sd/dnssrv/subscriber.go new file mode 100644 index 0000000..422fdaa --- /dev/null +++ b/sd/dnssrv/subscriber.go @@ -0,0 +1,100 @@ +package dnssrv + +import ( + "fmt" + "net" + "time" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/sd" + "github.com/go-kit/kit/sd/cache" +) + +// Subscriber yields endpoints taken from the named DNS SRV record. The name is +// resolved on a fixed schedule. Priorities and weights are ignored. +type Subscriber struct { + name string + cache *cache.Cache + logger log.Logger + quit chan struct{} +} + +// NewSubscriber returns a DNS SRV subscriber. +func NewSubscriber( + name string, + ttl time.Duration, + factory sd.Factory, + logger log.Logger, +) *Subscriber { + return NewSubscriberDetailed(name, time.NewTicker(ttl), net.LookupSRV, factory, logger) +} + +// NewSubscriberDetailed is the same as NewSubscriber, but allows users to +// provide an explicit lookup refresh ticker instead of a TTL, and specify the +// lookup function instead of using net.LookupSRV. +func NewSubscriberDetailed( + name string, + refresh *time.Ticker, + lookup Lookup, + factory sd.Factory, + logger log.Logger, +) *Subscriber { + p := &Subscriber{ + name: name, + cache: cache.New(factory, logger), + logger: logger, + quit: make(chan struct{}), + } + + instances, err := p.resolve(lookup) + if err == nil { + logger.Log("name", name, "instances", len(instances)) + } else { + logger.Log("name", name, "err", err) + } + p.cache.Update(instances) + + go p.loop(refresh, lookup) + return p +} + +// Stop terminates the Subscriber. +func (p *Subscriber) Stop() { + close(p.quit) +} + +func (p *Subscriber) loop(t *time.Ticker, lookup Lookup) { + defer t.Stop() + for { + select { + case <-t.C: + instances, err := p.resolve(lookup) + if err != nil { + p.logger.Log("name", p.name, "err", err) + continue // don't replace potentially-good with bad + } + p.cache.Update(instances) + + case <-p.quit: + return + } + } +} + +// Endpoints implements the Subscriber interface. 
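+// The returned slice is shared with the internal cache and must not be
+// modified by callers.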
+func (p *Subscriber) Endpoints() ([]endpoint.Endpoint, error) { + return p.cache.Endpoints(), nil +} + +func (p *Subscriber) resolve(lookup Lookup) ([]string, error) { + _, addrs, err := lookup("", "", p.name) + if err != nil { + return []string{}, err + } + instances := make([]string, len(addrs)) + for i, addr := range addrs { + instances[i] = net.JoinHostPort(addr.Target, fmt.Sprint(addr.Port)) + } + return instances, nil +} diff --git a/sd/dnssrv/subscriber_test.go b/sd/dnssrv/subscriber_test.go new file mode 100644 index 0000000..5a9036a --- /dev/null +++ b/sd/dnssrv/subscriber_test.go @@ -0,0 +1,85 @@ +package dnssrv + +import ( + "io" + "net" + "sync/atomic" + "testing" + "time" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" +) + +func TestRefresh(t *testing.T) { + name := "some.service.internal" + + ticker := time.NewTicker(time.Second) + ticker.Stop() + tickc := make(chan time.Time) + ticker.C = tickc + + var lookups uint64 + records := []*net.SRV{} + lookup := func(service, proto, name string) (string, []*net.SRV, error) { + t.Logf("lookup(%q, %q, %q)", service, proto, name) + atomic.AddUint64(&lookups, 1) + return "cname", records, nil + } + + var generates uint64 + factory := func(instance string) (endpoint.Endpoint, io.Closer, error) { + t.Logf("factory(%q)", instance) + atomic.AddUint64(&generates, 1) + return endpoint.Nop, nopCloser{}, nil + } + + subscriber := NewSubscriberDetailed(name, ticker, lookup, factory, log.NewNopLogger()) + defer subscriber.Stop() + + // First lookup, empty + endpoints, err := subscriber.Endpoints() + if err != nil { + t.Error(err) + } + if want, have := 0, len(endpoints); want != have { + t.Errorf("want %d, have %d", want, have) + } + if want, have := uint64(1), atomic.LoadUint64(&lookups); want != have { + t.Errorf("want %d, have %d", want, have) + } + if want, have := uint64(0), atomic.LoadUint64(&generates); want != have { + t.Errorf("want %d, have %d", want, have) + } + + // Load some records and lookup again + records = []*net.SRV{ + &net.SRV{Target: "1.0.0.1", Port: 1001}, + &net.SRV{Target: "1.0.0.2", Port: 1002}, + &net.SRV{Target: "1.0.0.3", Port: 1003}, + } + tickc <- time.Now() + + // There is a race condition where the subscriber.Endpoints call below + // invokes the cache before it is updated by the tick above. + // TODO(pb): solve by running the read through the loop goroutine. + time.Sleep(100 * time.Millisecond) + + endpoints, err = subscriber.Endpoints() + if err != nil { + t.Error(err) + } + if want, have := 3, len(endpoints); want != have { + t.Errorf("want %d, have %d", want, have) + } + if want, have := uint64(2), atomic.LoadUint64(&lookups); want != have { + t.Errorf("want %d, have %d", want, have) + } + if want, have := uint64(len(records)), atomic.LoadUint64(&generates); want != have { + t.Errorf("want %d, have %d", want, have) + } +} + +type nopCloser struct{} + +func (nopCloser) Close() error { return nil } diff --git a/sd/doc.go b/sd/doc.go new file mode 100644 index 0000000..b10d96f --- /dev/null +++ b/sd/doc.go @@ -0,0 +1,5 @@ +// Package sd provides utilities related to service discovery. That includes +// subscribing to service discovery systems in order to reach remote instances, +// and publishing to service discovery systems to make an instance available. +// Implementations are provided for most common systems. 
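+//
+// A typical composition, sketched here with a hypothetical subscriber, wires a
+// subscriber through a balancer into a retrying endpoint:
+//
+//	var subscriber sd.Subscriber = ... // e.g. from package consul, etcd, dnssrv, or zk
+//	balancer := lb.NewRoundRobin(subscriber)
+//	endpoint := lb.Retry(3, time.Second, balancer)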
+package sd
diff --git a/sd/etcd/client.go b/sd/etcd/client.go
new file mode 100644
index 0000000..b9e2904
--- /dev/null
+++ b/sd/etcd/client.go
@@ -0,0 +1,131 @@
+package etcd
+
+import (
+	"crypto/tls"
+	"crypto/x509"
+	"io/ioutil"
+	"net"
+	"net/http"
+	"time"
+
+	etcd "github.com/coreos/etcd/client"
+	"golang.org/x/net/context"
+)
+
+// Client is a wrapper around the etcd client.
+type Client interface {
+	// GetEntries queries the given prefix in etcd and returns a set of entries.
+	GetEntries(prefix string) ([]string, error)
+
+	// WatchPrefix starts watching every change for the given prefix in etcd.
+	// When a change is detected it sends an *etcd.Response on responseChan.
+	WatchPrefix(prefix string, responseChan chan *etcd.Response)
+}
+
+type client struct {
+	keysAPI etcd.KeysAPI
+	ctx     context.Context
+}
+
+// ClientOptions defines options for the etcd client.
+type ClientOptions struct {
+	Cert                    string
+	Key                     string
+	CaCert                  string
+	DialTimeout             time.Duration
+	DialKeepAlive           time.Duration
+	HeaderTimeoutPerRequest time.Duration
+}
+
+// NewClient returns a Client with a connection to the named machines. It
+// returns an error if a connection to the cluster cannot be made. Each machine
+// must be given as a full URL including the scheme:
+// e.g. "http://localhost:2379" works, but "localhost:2379" does not.
+func NewClient(ctx context.Context, machines []string, options ClientOptions) (Client, error) {
+	var (
+		c        etcd.KeysAPI
+		err      error
+		caCertCt []byte
+		tlsCert  tls.Certificate
+	)
+
+	if options.Cert != "" && options.Key != "" {
+		tlsCert, err = tls.LoadX509KeyPair(options.Cert, options.Key)
+		if err != nil {
+			return nil, err
+		}
+
+		caCertCt, err = ioutil.ReadFile(options.CaCert)
+		if err != nil {
+			return nil, err
+		}
+		caCertPool := x509.NewCertPool()
+		caCertPool.AppendCertsFromPEM(caCertCt)
+
+		tlsConfig := &tls.Config{
+			Certificates: []tls.Certificate{tlsCert},
+			RootCAs:      caCertPool,
+		}
+
+		transport := &http.Transport{
+			TLSClientConfig: tlsConfig,
+			Dial: func(network, addr string) (net.Conn, error) {
+				dial := &net.Dialer{
+					Timeout:   options.DialTimeout,
+					KeepAlive: options.DialKeepAlive,
+				}
+				return dial.Dial(network, addr)
+			},
+		}
+
+		cfg := etcd.Config{
+			Endpoints:               machines,
+			Transport:               transport,
+			HeaderTimeoutPerRequest: options.HeaderTimeoutPerRequest,
+		}
+		ce, err := etcd.New(cfg)
+		if err != nil {
+			return nil, err
+		}
+		c = etcd.NewKeysAPI(ce)
+	} else {
+		cfg := etcd.Config{
+			Endpoints:               machines,
+			Transport:               etcd.DefaultTransport,
+			HeaderTimeoutPerRequest: options.HeaderTimeoutPerRequest,
+		}
+		ce, err := etcd.New(cfg)
+		if err != nil {
+			return nil, err
+		}
+		c = etcd.NewKeysAPI(ce)
+	}
+
+	return &client{c, ctx}, nil
+}
+
+// GetEntries implements the etcd Client interface.
+func (c *client) GetEntries(key string) ([]string, error) {
+	resp, err := c.keysAPI.Get(c.ctx, key, &etcd.GetOptions{Recursive: true})
+	if err != nil {
+		return nil, err
+	}
+
+	entries := make([]string, len(resp.Node.Nodes))
+	for i, node := range resp.Node.Nodes {
+		entries[i] = node.Value
+	}
+	return entries, nil
+}
+
+// WatchPrefix implements the etcd Client interface.
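+// It blocks until the watcher returns an error, e.g. when the client context
+// is canceled, at which point it returns.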
+func (c *client) WatchPrefix(prefix string, responseChan chan *etcd.Response) {
+	watch := c.keysAPI.Watcher(prefix, &etcd.WatcherOptions{AfterIndex: 0, Recursive: true})
+	for {
+		res, err := watch.Next(c.ctx)
+		if err != nil {
+			return
+		}
+		responseChan <- res
+	}
+}
diff --git a/sd/etcd/subscriber.go b/sd/etcd/subscriber.go
new file mode 100644
index 0000000..1d579eb
--- /dev/null
+++ b/sd/etcd/subscriber.go
@@ -0,0 +1,74 @@
+package etcd
+
+import (
+	etcd "github.com/coreos/etcd/client"
+
+	"github.com/go-kit/kit/endpoint"
+	"github.com/go-kit/kit/log"
+	"github.com/go-kit/kit/sd"
+	"github.com/go-kit/kit/sd/cache"
+)
+
+// Subscriber yields endpoints stored in a certain etcd keyspace. Any kind of
+// change in that keyspace is watched and will update the Subscriber endpoints.
+type Subscriber struct {
+	client Client
+	prefix string
+	cache  *cache.Cache
+	logger log.Logger
+	quitc  chan struct{}
+}
+
+var _ sd.Subscriber = &Subscriber{}
+
+// NewSubscriber returns an etcd subscriber. It starts watching the given
+// prefix for changes and updates the endpoints accordingly.
+func NewSubscriber(c Client, prefix string, factory sd.Factory, logger log.Logger) (*Subscriber, error) {
+	s := &Subscriber{
+		client: c,
+		prefix: prefix,
+		cache:  cache.New(factory, logger),
+		logger: logger,
+		quitc:  make(chan struct{}),
+	}
+
+	instances, err := s.client.GetEntries(s.prefix)
+	if err == nil {
+		logger.Log("prefix", s.prefix, "instances", len(instances))
+	} else {
+		logger.Log("prefix", s.prefix, "err", err)
+	}
+	s.cache.Update(instances)
+
+	go s.loop()
+	return s, nil
+}
+
+func (s *Subscriber) loop() {
+	responseChan := make(chan *etcd.Response)
+	go s.client.WatchPrefix(s.prefix, responseChan)
+	for {
+		select {
+		case <-responseChan:
+			instances, err := s.client.GetEntries(s.prefix)
+			if err != nil {
+				s.logger.Log("msg", "failed to retrieve entries", "err", err)
+				continue
+			}
+			s.cache.Update(instances)
+
+		case <-s.quitc:
+			return
+		}
+	}
+}
+
+// Endpoints implements the Subscriber interface.
+func (s *Subscriber) Endpoints() ([]endpoint.Endpoint, error) {
+	return s.cache.Endpoints(), nil
+}
+
+// Stop terminates the Subscriber.
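+// It must be called at most once; the watch loop exits when quitc is closed.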
+func (s *Subscriber) Stop() { + close(s.quitc) +} diff --git a/sd/etcd/subscriber_test.go b/sd/etcd/subscriber_test.go new file mode 100644 index 0000000..0073e1e --- /dev/null +++ b/sd/etcd/subscriber_test.go @@ -0,0 +1,89 @@ +package etcd + +import ( + "errors" + "io" + "testing" + + stdetcd "github.com/coreos/etcd/client" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" +) + +var ( + node = &stdetcd.Node{ + Key: "/foo", + Nodes: []*stdetcd.Node{ + {Key: "/foo/1", Value: "1:1"}, + {Key: "/foo/2", Value: "1:2"}, + }, + } + fakeResponse = &stdetcd.Response{ + Node: node, + } +) + +func TestSubscriber(t *testing.T) { + factory := func(string) (endpoint.Endpoint, io.Closer, error) { + return endpoint.Nop, nil, nil + } + + client := &fakeClient{ + responses: map[string]*stdetcd.Response{"/foo": fakeResponse}, + } + + s, err := NewSubscriber(client, "/foo", factory, log.NewNopLogger()) + if err != nil { + t.Fatal(err) + } + defer s.Stop() + + if _, err := s.Endpoints(); err != nil { + t.Fatal(err) + } +} + +func TestBadFactory(t *testing.T) { + factory := func(string) (endpoint.Endpoint, io.Closer, error) { + return nil, nil, errors.New("kaboom") + } + + client := &fakeClient{ + responses: map[string]*stdetcd.Response{"/foo": fakeResponse}, + } + + s, err := NewSubscriber(client, "/foo", factory, log.NewNopLogger()) + if err != nil { + t.Fatal(err) + } + defer s.Stop() + + endpoints, err := s.Endpoints() + if err != nil { + t.Fatal(err) + } + + if want, have := 0, len(endpoints); want != have { + t.Errorf("want %d, have %d", want, have) + } +} + +type fakeClient struct { + responses map[string]*stdetcd.Response +} + +func (c *fakeClient) GetEntries(prefix string) ([]string, error) { + response, ok := c.responses[prefix] + if !ok { + return nil, errors.New("key not exist") + } + + entries := make([]string, len(response.Node.Nodes)) + for i, node := range response.Node.Nodes { + entries[i] = node.Value + } + return entries, nil +} + +func (c *fakeClient) WatchPrefix(prefix string, responseChan chan *stdetcd.Response) {} diff --git a/sd/factory.go b/sd/factory.go new file mode 100644 index 0000000..0e3e196 --- /dev/null +++ b/sd/factory.go @@ -0,0 +1,16 @@ +package sd + +import ( + "io" + + "github.com/go-kit/kit/endpoint" +) + +// Factory is a function that converts an instance string (e.g. host:port) to a +// specific endpoint. Instances that provide multiple endpoints require multiple +// factories. A factory also returns an io.Closer that's invoked when the +// instance goes away and needs to be cleaned up. +// +// Users are expected to provide their own factory functions that assume +// specific transports, or can deduce transports by parsing the instance string. +type Factory func(instance string) (endpoint.Endpoint, io.Closer, error) diff --git a/sd/fixed_subscriber.go b/sd/fixed_subscriber.go new file mode 100644 index 0000000..98fd503 --- /dev/null +++ b/sd/fixed_subscriber.go @@ -0,0 +1,9 @@ +package sd + +import "github.com/go-kit/kit/endpoint" + +// FixedSubscriber yields a fixed set of services. +type FixedSubscriber []endpoint.Endpoint + +// Endpoints implements Subscriber. +func (s FixedSubscriber) Endpoints() ([]endpoint.Endpoint, error) { return s, nil } diff --git a/sd/lb/balancer.go b/sd/lb/balancer.go new file mode 100644 index 0000000..40aa0ef --- /dev/null +++ b/sd/lb/balancer.go @@ -0,0 +1,15 @@ +package lb + +import ( + "errors" + + "github.com/go-kit/kit/endpoint" +) + +// Balancer yields endpoints according to some heuristic. 
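+// Implementations in this package include round-robin and random selection.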
+type Balancer interface { + Endpoint() (endpoint.Endpoint, error) +} + +// ErrNoEndpoints is returned when no qualifying endpoints are available. +var ErrNoEndpoints = errors.New("no endpoints available") diff --git a/sd/lb/doc.go b/sd/lb/doc.go new file mode 100644 index 0000000..82a9516 --- /dev/null +++ b/sd/lb/doc.go @@ -0,0 +1,5 @@ +// Package lb deals with client-side load balancing across multiple identical +// instances of services and endpoints. When combined with a service discovery +// system of record, it enables a more decentralized architecture, removing the +// need for separate load balancers like HAProxy. +package lb diff --git a/sd/lb/random.go b/sd/lb/random.go new file mode 100644 index 0000000..78b0956 --- /dev/null +++ b/sd/lb/random.go @@ -0,0 +1,32 @@ +package lb + +import ( + "math/rand" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/sd" +) + +// NewRandom returns a load balancer that selects services randomly. +func NewRandom(s sd.Subscriber, seed int64) Balancer { + return &random{ + s: s, + r: rand.New(rand.NewSource(seed)), + } +} + +type random struct { + s sd.Subscriber + r *rand.Rand +} + +func (r *random) Endpoint() (endpoint.Endpoint, error) { + endpoints, err := r.s.Endpoints() + if err != nil { + return nil, err + } + if len(endpoints) <= 0 { + return nil, ErrNoEndpoints + } + return endpoints[r.r.Intn(len(endpoints))], nil +} diff --git a/sd/lb/random_test.go b/sd/lb/random_test.go new file mode 100644 index 0000000..c9b0117 --- /dev/null +++ b/sd/lb/random_test.go @@ -0,0 +1,52 @@ +package lb + +import ( + "math" + "testing" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/sd" + "golang.org/x/net/context" +) + +func TestRandom(t *testing.T) { + var ( + n = 7 + endpoints = make([]endpoint.Endpoint, n) + counts = make([]int, n) + seed = int64(12345) + iterations = 1000000 + want = iterations / n + tolerance = want / 100 // 1% + ) + + for i := 0; i < n; i++ { + i0 := i + endpoints[i] = func(context.Context, interface{}) (interface{}, error) { counts[i0]++; return struct{}{}, nil } + } + + subscriber := sd.FixedSubscriber(endpoints) + balancer := NewRandom(subscriber, seed) + + for i := 0; i < iterations; i++ { + endpoint, _ := balancer.Endpoint() + endpoint(context.Background(), struct{}{}) + } + + for i, have := range counts { + delta := int(math.Abs(float64(want - have))) + if delta > tolerance { + t.Errorf("%d: want %d, have %d, delta %d > %d tolerance", i, want, have, delta, tolerance) + } + } +} + +func TestRandomNoEndpoints(t *testing.T) { + subscriber := sd.FixedSubscriber{} + balancer := NewRandom(subscriber, 1415926) + _, err := balancer.Endpoint() + if want, have := ErrNoEndpoints, err; want != have { + t.Errorf("want %v, have %v", want, have) + } + +} diff --git a/sd/lb/retry.go b/sd/lb/retry.go new file mode 100644 index 0000000..a933eeb --- /dev/null +++ b/sd/lb/retry.go @@ -0,0 +1,57 @@ +package lb + +import ( + "fmt" + "strings" + "time" + + "golang.org/x/net/context" + + "github.com/go-kit/kit/endpoint" +) + +// Retry wraps a service load balancer and returns an endpoint oriented load +// balancer for the specified service method. +// Requests to the endpoint will be automatically load balanced via the load +// balancer. Requests that return errors will be retried until they succeed, +// up to max times, or until the timeout is elapsed, whichever comes first. 
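+//
+// A hedged sketch of typical usage:
+//
+//	balancer := NewRoundRobin(subscriber)
+//	e := Retry(3, 500*time.Millisecond, balancer)
+//	response, err := e(ctx, request)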
+func Retry(max int, timeout time.Duration, b Balancer) endpoint.Endpoint { + if b == nil { + panic("nil Balancer") + } + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + var ( + newctx, cancel = context.WithTimeout(ctx, timeout) + responses = make(chan interface{}, 1) + errs = make(chan error, 1) + a = []string{} + ) + defer cancel() + for i := 1; i <= max; i++ { + go func() { + e, err := b.Endpoint() + if err != nil { + errs <- err + return + } + response, err := e(newctx, request) + if err != nil { + errs <- err + return + } + responses <- response + }() + + select { + case <-newctx.Done(): + return nil, newctx.Err() + case response := <-responses: + return response, nil + case err := <-errs: + a = append(a, err.Error()) + continue + } + } + return nil, fmt.Errorf("retry attempts exceeded (%s)", strings.Join(a, "; ")) + } +} diff --git a/sd/lb/retry_test.go b/sd/lb/retry_test.go new file mode 100644 index 0000000..07b1afd --- /dev/null +++ b/sd/lb/retry_test.go @@ -0,0 +1,90 @@ +package lb_test + +import ( + "errors" + "testing" + "time" + + "golang.org/x/net/context" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/sd" + loadbalancer "github.com/go-kit/kit/sd/lb" +) + +func TestRetryMaxTotalFail(t *testing.T) { + var ( + endpoints = sd.FixedSubscriber{} // no endpoints + lb = loadbalancer.NewRoundRobin(endpoints) + retry = loadbalancer.Retry(999, time.Second, lb) // lots of retries + ctx = context.Background() + ) + if _, err := retry(ctx, struct{}{}); err == nil { + t.Errorf("expected error, got none") // should fail + } +} + +func TestRetryMaxPartialFail(t *testing.T) { + var ( + endpoints = []endpoint.Endpoint{ + func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") }, + func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") }, + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ }, + } + subscriber = sd.FixedSubscriber{ + 0: endpoints[0], + 1: endpoints[1], + 2: endpoints[2], + } + retries = len(endpoints) - 1 // not quite enough retries + lb = loadbalancer.NewRoundRobin(subscriber) + ctx = context.Background() + ) + if _, err := loadbalancer.Retry(retries, time.Second, lb)(ctx, struct{}{}); err == nil { + t.Errorf("expected error, got none") + } +} + +func TestRetryMaxSuccess(t *testing.T) { + var ( + endpoints = []endpoint.Endpoint{ + func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") }, + func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") }, + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ }, + } + subscriber = sd.FixedSubscriber{ + 0: endpoints[0], + 1: endpoints[1], + 2: endpoints[2], + } + retries = len(endpoints) // exactly enough retries + lb = loadbalancer.NewRoundRobin(subscriber) + ctx = context.Background() + ) + if _, err := loadbalancer.Retry(retries, time.Second, lb)(ctx, struct{}{}); err != nil { + t.Error(err) + } +} + +func TestRetryTimeout(t *testing.T) { + var ( + step = make(chan struct{}) + e = func(context.Context, interface{}) (interface{}, error) { <-step; return struct{}{}, nil } + timeout = time.Millisecond + retry = loadbalancer.Retry(999, timeout, loadbalancer.NewRoundRobin(sd.FixedSubscriber{0: e})) + errs = make(chan error, 1) + invoke = func() { _, err := retry(context.Background(), struct{}{}); errs <- err } + ) + + go func() { step <- struct{}{} }() 
// queue up a flush of the endpoint + invoke() // invoke the endpoint and trigger the flush + if err := <-errs; err != nil { // that should succeed + t.Error(err) + } + + go func() { time.Sleep(10 * timeout); step <- struct{}{} }() // a delayed flush + invoke() // invoke the endpoint + if err := <-errs; err != context.DeadlineExceeded { // that should not succeed + t.Errorf("wanted %v, got none", context.DeadlineExceeded) + } +} diff --git a/sd/lb/round_robin.go b/sd/lb/round_robin.go new file mode 100644 index 0000000..74b86ca --- /dev/null +++ b/sd/lb/round_robin.go @@ -0,0 +1,34 @@ +package lb + +import ( + "sync/atomic" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/sd" +) + +// NewRoundRobin returns a load balancer that returns services in sequence. +func NewRoundRobin(s sd.Subscriber) Balancer { + return &roundRobin{ + s: s, + c: 0, + } +} + +type roundRobin struct { + s sd.Subscriber + c uint64 +} + +func (rr *roundRobin) Endpoint() (endpoint.Endpoint, error) { + endpoints, err := rr.s.Endpoints() + if err != nil { + return nil, err + } + if len(endpoints) <= 0 { + return nil, ErrNoEndpoints + } + old := atomic.AddUint64(&rr.c, 1) - 1 + idx := old % uint64(len(endpoints)) + return endpoints[idx], nil +} diff --git a/sd/lb/round_robin_test.go b/sd/lb/round_robin_test.go new file mode 100644 index 0000000..64a8baa --- /dev/null +++ b/sd/lb/round_robin_test.go @@ -0,0 +1,96 @@ +package lb + +import ( + "reflect" + "sync" + "sync/atomic" + "testing" + "time" + + "golang.org/x/net/context" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/sd" +) + +func TestRoundRobin(t *testing.T) { + var ( + counts = []int{0, 0, 0} + endpoints = []endpoint.Endpoint{ + func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil }, + func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil }, + func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil }, + } + ) + + subscriber := sd.FixedSubscriber(endpoints) + balancer := NewRoundRobin(subscriber) + + for i, want := range [][]int{ + {1, 0, 0}, + {1, 1, 0}, + {1, 1, 1}, + {2, 1, 1}, + {2, 2, 1}, + {2, 2, 2}, + {3, 2, 2}, + } { + endpoint, err := balancer.Endpoint() + if err != nil { + t.Fatal(err) + } + endpoint(context.Background(), struct{}{}) + if have := counts; !reflect.DeepEqual(want, have) { + t.Fatalf("%d: want %v, have %v", i, want, have) + } + } +} + +func TestRoundRobinNoEndpoints(t *testing.T) { + subscriber := sd.FixedSubscriber{} + balancer := NewRoundRobin(subscriber) + _, err := balancer.Endpoint() + if want, have := ErrNoEndpoints, err; want != have { + t.Errorf("want %v, have %v", want, have) + } +} + +func TestRoundRobinNoRace(t *testing.T) { + balancer := NewRoundRobin(sd.FixedSubscriber([]endpoint.Endpoint{ + endpoint.Nop, + endpoint.Nop, + endpoint.Nop, + endpoint.Nop, + endpoint.Nop, + })) + + var ( + n = 100 + done = make(chan struct{}) + wg sync.WaitGroup + count uint64 + ) + + wg.Add(n) + + for i := 0; i < n; i++ { + go func() { + defer wg.Done() + for { + select { + case <-done: + return + default: + _, _ = balancer.Endpoint() + atomic.AddUint64(&count, 1) + } + } + }() + } + + time.Sleep(time.Second) + close(done) + wg.Wait() + + t.Logf("made %d calls", atomic.LoadUint64(&count)) +} diff --git a/sd/registrar.go b/sd/registrar.go new file mode 100644 index 0000000..49a0c9f --- /dev/null +++ b/sd/registrar.go @@ -0,0 +1,13 @@ +package sd + +// Registrar registers instance information to a service 
discovery system when +// an instance becomes alive and healthy, and deregisters that information when +// the service becomes unhealthy or goes away. +// +// Registrar implementations exist for various service discovery systems. Note +// that identifying instance information (e.g. host:port) must be given via the +// concrete constructor; this interface merely signals lifecycle changes. +type Registrar interface { + Register() + Deregister() +} diff --git a/sd/subscriber.go b/sd/subscriber.go new file mode 100644 index 0000000..8267b51 --- /dev/null +++ b/sd/subscriber.go @@ -0,0 +1,11 @@ +package sd + +import "github.com/go-kit/kit/endpoint" + +// Subscriber listens to a service discovery system and yields a set of +// identical endpoints on demand. An error indicates a problem with connectivity +// to the service discovery system, or within the system itself; a subscriber +// may yield no endpoints without error. +type Subscriber interface { + Endpoints() ([]endpoint.Endpoint, error) +} diff --git a/sd/zk/client.go b/sd/zk/client.go new file mode 100644 index 0000000..8220817 --- /dev/null +++ b/sd/zk/client.go @@ -0,0 +1,231 @@ +package zk + +import ( + "errors" + "net" + "strings" + "time" + + "github.com/samuel/go-zookeeper/zk" + + "github.com/go-kit/kit/log" +) + +// DefaultACL is the default ACL to use for creating znodes. +var ( + DefaultACL = zk.WorldACL(zk.PermAll) + ErrInvalidCredentials = errors.New("invalid credentials provided") + ErrClientClosed = errors.New("client service closed") +) + +const ( + // DefaultConnectTimeout is the default timeout to establish a connection to + // a ZooKeeper node. + DefaultConnectTimeout = 2 * time.Second + // DefaultSessionTimeout is the default timeout to keep the current + // ZooKeeper session alive during a temporary disconnect. + DefaultSessionTimeout = 5 * time.Second +) + +// Client is a wrapper around a lower level ZooKeeper client implementation. +type Client interface { + // GetEntries should query the provided path in ZooKeeper, place a watch on + // it and retrieve data from its current child nodes. + GetEntries(path string) ([]string, <-chan zk.Event, error) + // CreateParentNodes should try to create the path in case it does not exist + // yet on ZooKeeper. + CreateParentNodes(path string) error + // Stop should properly shutdown the client implementation + Stop() +} + +type clientConfig struct { + logger log.Logger + acl []zk.ACL + credentials []byte + connectTimeout time.Duration + sessionTimeout time.Duration + rootNodePayload [][]byte + eventHandler func(zk.Event) +} + +// Option functions enable friendly APIs. +type Option func(*clientConfig) error + +type client struct { + *zk.Conn + clientConfig + active bool + quit chan struct{} +} + +// ACL returns an Option specifying a non-default ACL for creating parent nodes. +func ACL(acl []zk.ACL) Option { + return func(c *clientConfig) error { + c.acl = acl + return nil + } +} + +// Credentials returns an Option specifying a user/password combination which +// the client will use to authenticate itself with. +func Credentials(user, pass string) Option { + return func(c *clientConfig) error { + if user == "" || pass == "" { + return ErrInvalidCredentials + } + c.credentials = []byte(user + ":" + pass) + return nil + } +} + +// ConnectTimeout returns an Option specifying a non-default connection timeout +// when we try to establish a connection to a ZooKeeper server. 
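+// Timeouts shorter than one second are rejected.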
+func ConnectTimeout(t time.Duration) Option { + return func(c *clientConfig) error { + if t.Seconds() < 1 { + return errors.New("invalid connect timeout (minimum value is 1 second)") + } + c.connectTimeout = t + return nil + } +} + +// SessionTimeout returns an Option specifying a non-default session timeout. +func SessionTimeout(t time.Duration) Option { + return func(c *clientConfig) error { + if t.Seconds() < 1 { + return errors.New("invalid session timeout (minimum value is 1 second)") + } + c.sessionTimeout = t + return nil + } +} + +// Payload returns an Option specifying non-default data values for each znode +// created by CreateParentNodes. +func Payload(payload [][]byte) Option { + return func(c *clientConfig) error { + c.rootNodePayload = payload + return nil + } +} + +// EventHandler returns an Option specifying a callback function to handle +// incoming zk.Event payloads (ZooKeeper connection events). +func EventHandler(handler func(zk.Event)) Option { + return func(c *clientConfig) error { + c.eventHandler = handler + return nil + } +} + +// NewClient returns a ZooKeeper client with a connection to the server cluster. +// It will return an error if the server cluster cannot be resolved. +func NewClient(servers []string, logger log.Logger, options ...Option) (Client, error) { + defaultEventHandler := func(event zk.Event) { + logger.Log("eventtype", event.Type.String(), "server", event.Server, "state", event.State.String(), "err", event.Err) + } + config := clientConfig{ + acl: DefaultACL, + connectTimeout: DefaultConnectTimeout, + sessionTimeout: DefaultSessionTimeout, + eventHandler: defaultEventHandler, + logger: logger, + } + for _, option := range options { + if err := option(&config); err != nil { + return nil, err + } + } + // dialer overrides the default ZooKeeper library Dialer so we can configure + // the connectTimeout. The current library has a hardcoded value of 1 second + // and there are reports of race conditions, due to slow DNS resolvers and + // other network latency issues. + dialer := func(network, address string, _ time.Duration) (net.Conn, error) { + return net.DialTimeout(network, address, config.connectTimeout) + } + conn, eventc, err := zk.Connect(servers, config.sessionTimeout, withLogger(logger), zk.WithDialer(dialer)) + + if err != nil { + return nil, err + } + + if len(config.credentials) > 0 { + err = conn.AddAuth("digest", config.credentials) + if err != nil { + return nil, err + } + } + + c := &client{conn, config, true, make(chan struct{})} + + // Start listening for incoming Event payloads and callback the set + // eventHandler. + go func() { + for { + select { + case event := <-eventc: + config.eventHandler(event) + case <-c.quit: + return + } + } + }() + return c, nil +} + +// CreateParentNodes implements the ZooKeeper Client interface. +func (c *client) CreateParentNodes(path string) error { + if !c.active { + return ErrClientClosed + } + if path[0] != '/' { + return zk.ErrInvalidPath + } + payload := []byte("") + pathString := "" + pathNodes := strings.Split(path, "/") + for i := 1; i < len(pathNodes); i++ { + if i <= len(c.rootNodePayload) { + payload = c.rootNodePayload[i-1] + } else { + payload = []byte("") + } + pathString += "/" + pathNodes[i] + _, err := c.Create(pathString, payload, 0, c.acl) + // not being able to create the node because it exists or not having + // sufficient rights is not an issue. 
It is ok for the node to already + // exist and/or us to only have read rights + if err != nil && err != zk.ErrNodeExists && err != zk.ErrNoAuth { + return err + } + } + return nil +} + +// GetEntries implements the ZooKeeper Client interface. +func (c *client) GetEntries(path string) ([]string, <-chan zk.Event, error) { + // retrieve list of child nodes for given path and add watch to path + znodes, _, eventc, err := c.ChildrenW(path) + + if err != nil { + return nil, eventc, err + } + + var resp []string + for _, znode := range znodes { + // retrieve payload for child znode and add to response array + if data, _, err := c.Get(path + "/" + znode); err == nil { + resp = append(resp, string(data)) + } + } + return resp, eventc, nil +} + +// Stop implements the ZooKeeper Client interface. +func (c *client) Stop() { + c.active = false + close(c.quit) + c.Close() +} diff --git a/sd/zk/client_test.go b/sd/zk/client_test.go new file mode 100644 index 0000000..fbb2a5a --- /dev/null +++ b/sd/zk/client_test.go @@ -0,0 +1,157 @@ +package zk + +import ( + "bytes" + "testing" + "time" + + stdzk "github.com/samuel/go-zookeeper/zk" + + "github.com/go-kit/kit/log" +) + +func TestNewClient(t *testing.T) { + var ( + acl = stdzk.WorldACL(stdzk.PermRead) + connectTimeout = 3 * time.Second + sessionTimeout = 20 * time.Second + payload = [][]byte{[]byte("Payload"), []byte("Test")} + ) + + c, err := NewClient( + []string{"FailThisInvalidHost!!!"}, + log.NewNopLogger(), + ) + if err == nil { + t.Errorf("expected error, got nil") + } + + hasFired := false + calledEventHandler := make(chan struct{}) + eventHandler := func(event stdzk.Event) { + if !hasFired { + // test is successful if this function has fired at least once + hasFired = true + close(calledEventHandler) + } + } + + c, err = NewClient( + []string{"localhost"}, + log.NewNopLogger(), + ACL(acl), + ConnectTimeout(connectTimeout), + SessionTimeout(sessionTimeout), + Payload(payload), + EventHandler(eventHandler), + ) + if err != nil { + t.Fatal(err) + } + defer c.Stop() + + clientImpl, ok := c.(*client) + if !ok { + t.Fatal("retrieved incorrect Client implementation") + } + if want, have := acl, clientImpl.acl; want[0] != have[0] { + t.Errorf("want %+v, have %+v", want, have) + } + if want, have := connectTimeout, clientImpl.connectTimeout; want != have { + t.Errorf("want %d, have %d", want, have) + } + if want, have := sessionTimeout, clientImpl.sessionTimeout; want != have { + t.Errorf("want %d, have %d", want, have) + } + if want, have := payload, clientImpl.rootNodePayload; bytes.Compare(want[0], have[0]) != 0 || bytes.Compare(want[1], have[1]) != 0 { + t.Errorf("want %s, have %s", want, have) + } + + select { + case <-calledEventHandler: + case <-time.After(100 * time.Millisecond): + t.Errorf("event handler never called") + } +} + +func TestOptions(t *testing.T) { + _, err := NewClient([]string{"localhost"}, log.NewNopLogger(), Credentials("valid", "credentials")) + if err != nil && err != stdzk.ErrNoServer { + t.Errorf("unexpected error: %v", err) + } + + _, err = NewClient([]string{"localhost"}, log.NewNopLogger(), Credentials("nopass", "")) + if want, have := err, ErrInvalidCredentials; want != have { + t.Errorf("want %v, have %v", want, have) + } + + _, err = NewClient([]string{"localhost"}, log.NewNopLogger(), ConnectTimeout(0)) + if err == nil { + t.Errorf("expected connect timeout error") + } + + _, err = NewClient([]string{"localhost"}, log.NewNopLogger(), SessionTimeout(0)) + if err == nil { + t.Errorf("expected connect timeout error") + } 
+} + +func TestCreateParentNodes(t *testing.T) { + payload := [][]byte{[]byte("Payload"), []byte("Test")} + + c, err := NewClient([]string{"localhost:65500"}, log.NewNopLogger()) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if c == nil { + t.Fatal("expected new Client, got nil") + } + + s, err := NewSubscriber(c, "/validpath", newFactory(""), log.NewNopLogger()) + if err != stdzk.ErrNoServer { + t.Errorf("unexpected error: %v", err) + } + if s != nil { + t.Error("expected failed new Subscriber") + } + + s, err = NewSubscriber(c, "invalidpath", newFactory(""), log.NewNopLogger()) + if err != stdzk.ErrInvalidPath { + t.Errorf("unexpected error: %v", err) + } + _, _, err = c.GetEntries("/validpath") + if err != stdzk.ErrNoServer { + t.Errorf("unexpected error: %v", err) + } + + c.Stop() + + err = c.CreateParentNodes("/validpath") + if err != ErrClientClosed { + t.Errorf("unexpected error: %v", err) + } + + s, err = NewSubscriber(c, "/validpath", newFactory(""), log.NewNopLogger()) + if err != ErrClientClosed { + t.Errorf("unexpected error: %v", err) + } + if s != nil { + t.Error("expected failed new Subscriber") + } + + c, err = NewClient([]string{"localhost:65500"}, log.NewNopLogger(), Payload(payload)) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if c == nil { + t.Fatal("expected new Client, got nil") + } + + s, err = NewSubscriber(c, "/validpath", newFactory(""), log.NewNopLogger()) + if err != stdzk.ErrNoServer { + t.Errorf("unexpected error: %v", err) + } + if s != nil { + t.Error("expected failed new Subscriber") + } +} diff --git a/sd/zk/integration_test.go b/sd/zk/integration_test.go new file mode 100644 index 0000000..0e67679 --- /dev/null +++ b/sd/zk/integration_test.go @@ -0,0 +1,201 @@ +// +build integration + +package zk + +import ( + "bytes" + "flag" + "fmt" + "os" + "testing" + "time" + + stdzk "github.com/samuel/go-zookeeper/zk" +) + +var host []string + +func TestMain(m *testing.M) { + flag.Parse() + + fmt.Println("Starting ZooKeeper server...") + + ts, err := stdzk.StartTestCluster(1, nil, nil) + if err != nil { + fmt.Printf("ZooKeeper server error: %v\n", err) + os.Exit(1) + } + + host = []string{fmt.Sprintf("localhost:%d", ts.Servers[0].Port)} + code := m.Run() + + ts.Stop() + os.Exit(code) +} + +func TestCreateParentNodesOnServer(t *testing.T) { + payload := [][]byte{[]byte("Payload"), []byte("Test")} + c1, err := NewClient(host, logger, Payload(payload)) + if err != nil { + t.Fatalf("Connect returned error: %v", err) + } + if c1 == nil { + t.Fatal("Expected pointer to client, got nil") + } + defer c1.Stop() + + s, err := NewSubscriber(c1, path, newFactory(""), logger) + if err != nil { + t.Fatalf("Unable to create Subscriber: %v", err) + } + defer s.Stop() + + endpoints, err := s.Endpoints() + if err != nil { + t.Fatal(err) + } + if want, have := 0, len(endpoints); want != have { + t.Errorf("want %d, have %d", want, have) + } + + c2, err := NewClient(host, logger) + if err != nil { + t.Fatalf("Connect returned error: %v", err) + } + defer c2.Stop() + data, _, err := c2.(*client).Get(path) + if err != nil { + t.Fatal(err) + } + // Test the Client implementation of CreateParentNodes.
It should have created + // the root node with our payload. + if !bytes.Equal(data, payload[1]) { + t.Errorf("want %s, have %s", payload[1], data) + } +} + +func TestCreateBadParentNodesOnServer(t *testing.T) { + c, _ := NewClient(host, logger) + defer c.Stop() + + _, err := NewSubscriber(c, "invalid/path", newFactory(""), logger) + + if want, have := stdzk.ErrInvalidPath, err; want != have { + t.Errorf("want %v, have %v", want, have) + } +} + +func TestCredentials1(t *testing.T) { + acl := stdzk.DigestACL(stdzk.PermAll, "user", "secret") + c, _ := NewClient(host, logger, ACL(acl), Credentials("user", "secret")) + defer c.Stop() + + _, err := NewSubscriber(c, "/acl-issue-test", newFactory(""), logger) + + if err != nil { + t.Fatal(err) + } +} + +func TestCredentials2(t *testing.T) { + acl := stdzk.DigestACL(stdzk.PermAll, "user", "secret") + c, _ := NewClient(host, logger, ACL(acl)) + defer c.Stop() + + _, err := NewSubscriber(c, "/acl-issue-test", newFactory(""), logger) + + if err != stdzk.ErrNoAuth { + t.Errorf("want %v, have %v", stdzk.ErrNoAuth, err) + } +} + +func TestConnection(t *testing.T) { + c, _ := NewClient(host, logger) + c.Stop() + + _, err := NewSubscriber(c, "/acl-issue-test", newFactory(""), logger) + + if err != ErrClientClosed { + t.Errorf("want %v, have %v", ErrClientClosed, err) + } +} + +func TestGetEntriesOnServer(t *testing.T) { + var instancePayload = "protocol://hostname:port/routing" + + c1, err := NewClient(host, logger) + if err != nil { + t.Fatalf("Connect returned error: %v", err) + } + defer c1.Stop() + + c2, err := NewClient(host, logger) + if err != nil { + t.Fatalf("Connect returned error: %v", err) + } + defer c2.Stop() + + s, err := NewSubscriber(c2, path, newFactory(""), logger) + if err != nil { + t.Fatal(err) + } + defer s.Stop() + + c2impl, _ := c2.(*client) + _, err = c2impl.Create( + path+"/instance1", + []byte(instancePayload), + stdzk.FlagEphemeral|stdzk.FlagSequence, + stdzk.WorldACL(stdzk.PermAll), + ) + if err != nil { + t.Fatalf("Unable to create test ephemeral znode 1: %v", err) + } + _, err = c2impl.Create( + path+"/instance2", + []byte(instancePayload+"2"), + stdzk.FlagEphemeral|stdzk.FlagSequence, + stdzk.WorldACL(stdzk.PermAll), + ) + if err != nil { + t.Fatalf("Unable to create test ephemeral znode 2: %v", err) + } + + time.Sleep(50 * time.Millisecond) + + endpoints, err := s.Endpoints() + if err != nil { + t.Fatal(err) + } + if want, have := 2, len(endpoints); want != have { + t.Errorf("want %d, have %d", want, have) + } +} + +func TestGetEntriesPayloadOnServer(t *testing.T) { + c, err := NewClient(host, logger) + if err != nil { + t.Fatalf("Connect returned error: %v", err) + } + defer c.Stop() + + _, eventc, err := c.GetEntries(path) + if err != nil { + t.Fatal(err) + } + _, err = c.(*client).Create( + path+"/instance3", + []byte("just some payload"), + stdzk.FlagEphemeral|stdzk.FlagSequence, + stdzk.WorldACL(stdzk.PermAll), + ) + if err != nil { + t.Fatalf("Unable to create test ephemeral znode: %v", err) + } + select { + case event := <-eventc: + if want, have := stdzk.EventNodeChildrenChanged.String(), event.Type.String(); want != have { + t.Errorf("want %s, have %s", want, have) + } + case <-time.After(20 * time.Millisecond): + t.Errorf("expected incoming watch event, timeout occurred") + } +} diff --git a/sd/zk/logwrapper.go b/sd/zk/logwrapper.go new file mode 100644 index 0000000..abb7b6d --- /dev/null +++ b/sd/zk/logwrapper.go @@ -0,0 +1,27 @@ +package zk + +import ( + "fmt" + + "github.com/samuel/go-zookeeper/zk" + + "github.com/go-kit/kit/log" +) + +// wrapLogger wraps a Go kit logger so we can use it as the logging service for +// the
ZooKeeper library, which expects a Printf method to be available. +type wrapLogger struct { + log.Logger +} + +func (logger wrapLogger) Printf(format string, args ...interface{}) { + logger.Log("msg", fmt.Sprintf(format, args...)) +} + +// withLogger replaces the ZooKeeper library's default logging service with our +// own Go kit logger. +func withLogger(logger log.Logger) func(c *zk.Conn) { + return func(c *zk.Conn) { + c.SetLogger(wrapLogger{logger}) + } +} diff --git a/sd/zk/subscriber.go b/sd/zk/subscriber.go new file mode 100644 index 0000000..b9c67db --- /dev/null +++ b/sd/zk/subscriber.go @@ -0,0 +1,86 @@ +package zk + +import ( + "github.com/samuel/go-zookeeper/zk" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/sd" + "github.com/go-kit/kit/sd/cache" +) + +// Subscriber yields endpoints stored in a certain ZooKeeper path. Any change +// in that path is watched and updates the Subscriber's endpoints. +type Subscriber struct { + client Client + path string + cache *cache.Cache + logger log.Logger + quitc chan struct{} +} + +var _ sd.Subscriber = &Subscriber{} + +// NewSubscriber returns a ZooKeeper subscriber. It begins watching the given +// path for changes, and updates the Subscriber's endpoints accordingly. +func NewSubscriber(c Client, path string, factory sd.Factory, logger log.Logger) (*Subscriber, error) { + s := &Subscriber{ + client: c, + path: path, + cache: cache.New(factory, logger), + logger: logger, + quitc: make(chan struct{}), + } + + err := s.client.CreateParentNodes(s.path) + if err != nil { + return nil, err + } + + instances, eventc, err := s.client.GetEntries(s.path) + if err != nil { + logger.Log("path", s.path, "msg", "failed to retrieve entries", "err", err) + return nil, err + } + logger.Log("path", s.path, "instances", len(instances)) + s.cache.Update(instances) + + go s.loop(eventc) + + return s, nil +} + +func (s *Subscriber) loop(eventc <-chan zk.Event) { + var ( + instances []string + err error + ) + for { + select { + case <-eventc: + // We received a path update notification. Call GetEntries to + // retrieve child node data, and set a new watch, as ZK watches are + // one-time triggers. + instances, eventc, err = s.client.GetEntries(s.path) + if err != nil { + s.logger.Log("path", s.path, "msg", "failed to retrieve entries", "err", err) + continue + } + s.logger.Log("path", s.path, "instances", len(instances)) + s.cache.Update(instances) + + case <-s.quitc: + return + } + } +} + +// Endpoints implements the Subscriber interface. +func (s *Subscriber) Endpoints() ([]endpoint.Endpoint, error) { + return s.cache.Endpoints(), nil +} + +// Stop terminates the Subscriber.
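+// It must be called at most once; closing the already-closed quit channel would panic.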
+func (s *Subscriber) Stop() { + close(s.quitc) +} diff --git a/sd/zk/subscriber_test.go b/sd/zk/subscriber_test.go new file mode 100644 index 0000000..79bdb84 --- /dev/null +++ b/sd/zk/subscriber_test.go @@ -0,0 +1,117 @@ +package zk + +import ( + "testing" + "time" +) + +func TestSubscriber(t *testing.T) { + client := newFakeClient() + + s, err := NewSubscriber(client, path, newFactory(""), logger) + if err != nil { + t.Fatalf("failed to create new Subscriber: %v", err) + } + defer s.Stop() + + if _, err := s.Endpoints(); err != nil { + t.Fatal(err) + } +} + +func TestBadFactory(t *testing.T) { + client := newFakeClient() + + s, err := NewSubscriber(client, path, newFactory("kaboom"), logger) + if err != nil { + t.Fatalf("failed to create new Subscriber: %v", err) + } + defer s.Stop() + + // instance1 came online + client.AddService(path+"/instance1", "kaboom") + + // instance2 came online + client.AddService(path+"/instance2", "zookeeper_node_data") + + if err = asyncTest(100*time.Millisecond, 1, s); err != nil { + t.Error(err) + } +} + +func TestServiceUpdate(t *testing.T) { + client := newFakeClient() + + s, err := NewSubscriber(client, path, newFactory(""), logger) + if err != nil { + t.Fatalf("failed to create new Subscriber: %v", err) + } + defer s.Stop() + + endpoints, err := s.Endpoints() + if err != nil { + t.Fatal(err) + } + if want, have := 0, len(endpoints); want != have { + t.Errorf("want %d, have %d", want, have) + } + + // instance1 came online + client.AddService(path+"/instance1", "zookeeper_node_data1") + + // instance2 came online + client.AddService(path+"/instance2", "zookeeper_node_data2") + + // we should have 2 instances + if err = asyncTest(100*time.Millisecond, 2, s); err != nil { + t.Error(err) + } + + // TODO(pb): this bit is flaky + // + //// watch triggers an error... 
+ //client.SendErrorOnWatch() + // + //// test if error was consumed + //if err = client.ErrorIsConsumedWithin(100 * time.Millisecond); err != nil { + // t.Error(err) + //} + + // instance3 came online + client.AddService(path+"/instance3", "zookeeper_node_data3") + + // we should have 3 instances + if err = asyncTest(100*time.Millisecond, 3, s); err != nil { + t.Error(err) + } + + // instance1 goes offline + client.RemoveService(path + "/instance1") + + // instance2 goes offline + client.RemoveService(path + "/instance2") + + // we should have 1 instance + if err = asyncTest(100*time.Millisecond, 1, s); err != nil { + t.Error(err) + } +} + +func TestBadSubscriberCreate(t *testing.T) { + client := newFakeClient() + client.SendErrorOnWatch() + s, err := NewSubscriber(client, path, newFactory(""), logger) + if err == nil { + t.Error("expected error on new Subscriber") + } + if s != nil { + t.Error("expected Subscriber not to be created") + } + s, err = NewSubscriber(client, "BadPath", newFactory(""), logger) + if err == nil { + t.Error("expected error on new Subscriber") + } + if s != nil { + t.Error("expected Subscriber not to be created") + } +} diff --git a/sd/zk/util_test.go b/sd/zk/util_test.go new file mode 100644 index 0000000..2a4e1fe --- /dev/null +++ b/sd/zk/util_test.go @@ -0,0 +1,126 @@ +package zk + +import ( + "errors" + "fmt" + "io" + "sync" + "time" + + "github.com/samuel/go-zookeeper/zk" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/sd" +) + +var ( + path = "/gokit.test/service.name" + logger = log.NewNopLogger() +) + +type fakeClient struct { + mtx sync.Mutex + ch chan zk.Event + responses map[string]string + result bool +} + +func newFakeClient() *fakeClient { + return &fakeClient{ + ch: make(chan zk.Event, 1), + responses: make(map[string]string), + result: true, + } +} + +func (c *fakeClient) CreateParentNodes(path string) error { + if path == "BadPath" { + return errors.New("dummy error") + } + return nil +} + +func (c *fakeClient) GetEntries(path string) ([]string, <-chan zk.Event, error) { + c.mtx.Lock() + defer c.mtx.Unlock() + if !c.result { + c.result = true + return []string{}, c.ch, errors.New("dummy error") + } + responses := []string{} + for _, data := range c.responses { + responses = append(responses, data) + } + return responses, c.ch, nil +} + +func (c *fakeClient) AddService(node, data string) { + c.mtx.Lock() + defer c.mtx.Unlock() + c.responses[node] = data + c.ch <- zk.Event{} +} + +func (c *fakeClient) RemoveService(node string) { + c.mtx.Lock() + defer c.mtx.Unlock() + delete(c.responses, node) + c.ch <- zk.Event{} +} + +func (c *fakeClient) SendErrorOnWatch() { + c.mtx.Lock() + defer c.mtx.Unlock() + c.result = false + c.ch <- zk.Event{} +} + +func (c *fakeClient) ErrorIsConsumedWithin(timeout time.Duration) error { + t := time.After(timeout) + for { + select { + case <-t: + return fmt.Errorf("expected error not consumed after timeout %s", timeout) + default: + c.mtx.Lock() + if !c.result { + c.mtx.Unlock() + return nil + } + c.mtx.Unlock() + } + } +} + +func (c *fakeClient) Stop() {} + +func newFactory(fakeError string) sd.Factory { + return func(instance string) (endpoint.Endpoint, io.Closer, error) { + if fakeError == instance { + return nil, nil, errors.New(fakeError) + } + return endpoint.Nop, nil, nil + } +}  + +func asyncTest(timeout time.Duration, want int, s
*Subscriber) (err error) { + var endpoints []endpoint.Endpoint + have := -1 // want can never be <0 + t := time.After(timeout) + for { + select { + case <-t: + return fmt.Errorf("want %d, have %d (timeout %s)", want, have, timeout.String()) + default: + endpoints, err = s.Endpoints() + have = len(endpoints) + if err != nil || want == have { + return + } + time.Sleep(timeout / 10) + } + } +}
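For reviewers who want to see the moving parts end to end, a minimal, hypothetical wiring of the new zk package might look like the sketch below. The ensemble address localhost:2181 and the path /services/stringsvc are placeholders, and the no-op factory stands in for a real endpoint constructor; none of this is part of the change itself.

package main

import (
	"io"

	"github.com/go-kit/kit/endpoint"
	"github.com/go-kit/kit/log"
	"github.com/go-kit/kit/sd"
	"github.com/go-kit/kit/sd/zk"
)

func main() {
	logger := log.NewNopLogger()

	// Connect to the ensemble; Stop tears down the session and the client.
	client, err := zk.NewClient([]string{"localhost:2181"}, logger)
	if err != nil {
		panic(err)
	}
	defer client.Stop()

	// A factory converts a discovered instance string (a znode payload) into
	// an endpoint. endpoint.Nop stands in for a real constructor here.
	var factory sd.Factory = func(instance string) (endpoint.Endpoint, io.Closer, error) {
		return endpoint.Nop, nil, nil
	}

	// The Subscriber creates parent nodes if needed, watches the path, and
	// keeps its endpoint cache current as instances come and go.
	subscriber, err := zk.NewSubscriber(client, "/services/stringsvc", factory, logger)
	if err != nil {
		panic(err)
	}
	defer subscriber.Stop()

	// Endpoints returns the currently cached endpoints.
	endpoints, _ := subscriber.Endpoints()
	_ = endpoints // typically consumed by a load balancer
}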