1
0
mirror of https://github.com/go-kit/kit.git synced 2025-07-15 01:04:44 +02:00

sd: port, without service.Service

This commit is contained in:
Peter Bourgon
2016-05-25 15:39:18 -06:00
parent 7de0f49c7f
commit 9a19822c46
37 changed files with 2753 additions and 0 deletions

View File

@ -10,6 +10,10 @@ import (
// It represents a single RPC method.
type Endpoint func(ctx context.Context, request interface{}) (response interface{}, err error)
// Nop is an endpoint that does nothing and returns a nil error.
// Useful for tests.
func Nop(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
// Middleware is a chainable behavior modifier for endpoints.
type Middleware func(Endpoint) Endpoint

29
sd/cache/benchmark_test.go vendored Normal file
View File

@ -0,0 +1,29 @@
package cache
import (
"io"
"testing"
"github.com/go-kit/kit/endpoint"
"github.com/go-kit/kit/log"
)
func BenchmarkEndpoints(b *testing.B) {
var (
ca = make(closer)
cb = make(closer)
cmap = map[string]io.Closer{"a": ca, "b": cb}
factory = func(instance string) (endpoint.Endpoint, io.Closer, error) { return endpoint.Nop, cmap[instance], nil }
c = New(factory, log.NewNopLogger())
)
b.ReportAllocs()
c.Update([]string{"a", "b"})
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
c.Endpoints()
}
})
}

97
sd/cache/cache.go vendored Normal file
View File

@ -0,0 +1,97 @@
package cache
import (
"io"
"sort"
"sync"
"github.com/go-kit/kit/endpoint"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/sd"
)
// Cache collects the most recent set of endpoints from a service discovery
// system via a subscriber, and makes them available to consumers. Cache is
// meant to be embedded inside of a concrete subscriber, and can serve Service
// invocations directly.
type Cache struct {
mtx sync.RWMutex
factory sd.Factory
cache map[string]endpointCloser
slice []endpoint.Endpoint
logger log.Logger
}
type endpointCloser struct {
endpoint.Endpoint
io.Closer
}
// New returns a new, empty endpoint cache.
func New(factory sd.Factory, logger log.Logger) *Cache {
return &Cache{
factory: factory,
cache: map[string]endpointCloser{},
logger: logger,
}
}
// Update should be invoked by clients with a complete set of current instance
// strings whenever that set changes. The cache manufactures new endpoints via
// the factory, closes old endpoints when they disappear, and persists existing
// endpoints if they survive through an update.
func (c *Cache) Update(instances []string) {
c.mtx.Lock()
defer c.mtx.Unlock()
// Deterministic order (for later).
sort.Strings(instances)
// Produce the current set of services.
cache := make(map[string]endpointCloser, len(instances))
for _, instance := range instances {
// If it already exists, just copy it over.
if sc, ok := c.cache[instance]; ok {
cache[instance] = sc
delete(c.cache, instance)
continue
}
// If it doesn't exist, create it.
service, closer, err := c.factory(instance)
if err != nil {
c.logger.Log("instance", instance, "err", err)
continue
}
cache[instance] = endpointCloser{service, closer}
}
// Close any leftover endpoints.
for _, sc := range c.cache {
if sc.Closer != nil {
sc.Closer.Close()
}
}
// Populate the slice of endpoints.
slice := make([]endpoint.Endpoint, 0, len(cache))
for _, instance := range instances {
// A bad factory may mean an instance is not present.
if _, ok := cache[instance]; !ok {
continue
}
slice = append(slice, cache[instance].Endpoint)
}
// Swap and trigger GC for old copies.
c.slice = slice
c.cache = cache
}
// Endpoints yields the current set of (presumably identical) endpoints, ordered
// lexicographically by the corresponding instance string.
func (c *Cache) Endpoints() []endpoint.Endpoint {
c.mtx.RLock()
defer c.mtx.RUnlock()
return c.slice
}

91
sd/cache/cache_test.go vendored Normal file
View File

@ -0,0 +1,91 @@
package cache
import (
"errors"
"io"
"testing"
"time"
"github.com/go-kit/kit/endpoint"
"github.com/go-kit/kit/log"
)
func TestCache(t *testing.T) {
var (
ca = make(closer)
cb = make(closer)
c = map[string]io.Closer{"a": ca, "b": cb}
f = func(instance string) (endpoint.Endpoint, io.Closer, error) { return endpoint.Nop, c[instance], nil }
cache = New(f, log.NewNopLogger())
)
// Populate
cache.Update([]string{"a", "b"})
select {
case <-ca:
t.Errorf("endpoint a closed, not good")
case <-cb:
t.Errorf("endpoint b closed, not good")
case <-time.After(time.Millisecond):
t.Logf("no closures yet, good")
}
if want, have := 2, len(cache.Endpoints()); want != have {
t.Errorf("want %d, have %d", want, have)
}
// Duplicate, should be no-op
cache.Update([]string{"a", "b"})
select {
case <-ca:
t.Errorf("endpoint a closed, not good")
case <-cb:
t.Errorf("endpoint b closed, not good")
case <-time.After(time.Millisecond):
t.Logf("no closures yet, good")
}
if want, have := 2, len(cache.Endpoints()); want != have {
t.Errorf("want %d, have %d", want, have)
}
// Delete b
go cache.Update([]string{"a"})
select {
case <-ca:
t.Errorf("endpoint a closed, not good")
case <-cb:
t.Logf("endpoint b closed, good")
case <-time.After(time.Second):
t.Errorf("didn't close the deleted instance in time")
}
if want, have := 1, len(cache.Endpoints()); want != have {
t.Errorf("want %d, have %d", want, have)
}
// Delete a
go cache.Update([]string{})
select {
// case <-cb: will succeed, as it's closed
case <-ca:
t.Logf("endpoint a closed, good")
case <-time.After(time.Second):
t.Errorf("didn't close the deleted instance in time")
}
if want, have := 0, len(cache.Endpoints()); want != have {
t.Errorf("want %d, have %d", want, have)
}
}
func TestBadFactory(t *testing.T) {
cache := New(func(string) (endpoint.Endpoint, io.Closer, error) {
return nil, nil, errors.New("bad factory")
}, log.NewNopLogger())
cache.Update([]string{"foo:1234", "bar:5678"})
if want, have := 0, len(cache.Endpoints()); want != have {
t.Errorf("want %d, have %d", want, have)
}
}
type closer chan struct{}
func (c closer) Close() error { close(c); return nil }

37
sd/consul/client.go Normal file
View File

@ -0,0 +1,37 @@
package consul
import consul "github.com/hashicorp/consul/api"
// Client is a wrapper around the Consul API.
type Client interface {
// Register a service with the local agent.
Register(r *consul.AgentServiceRegistration) error
// Deregister a service with the local agent.
Deregister(r *consul.AgentServiceRegistration) error
// Service
Service(service, tag string, passingOnly bool, queryOpts *consul.QueryOptions) ([]*consul.ServiceEntry, *consul.QueryMeta, error)
}
type client struct {
consul *consul.Client
}
// NewClient returns an implementation of the Client interface, wrapping a
// concrete Consul client.
func NewClient(c *consul.Client) Client {
return &client{consul: c}
}
func (c *client) Register(r *consul.AgentServiceRegistration) error {
return c.consul.Agent().ServiceRegister(r)
}
func (c *client) Deregister(r *consul.AgentServiceRegistration) error {
return c.consul.Agent().ServiceDeregister(r.ID)
}
func (c *client) Service(service, tag string, passingOnly bool, queryOpts *consul.QueryOptions) ([]*consul.ServiceEntry, *consul.QueryMeta, error) {
return c.consul.Health().Service(service, tag, passingOnly, queryOpts)
}

156
sd/consul/client_test.go Normal file
View File

@ -0,0 +1,156 @@
package consul
import (
"errors"
"io"
"reflect"
"testing"
stdconsul "github.com/hashicorp/consul/api"
"golang.org/x/net/context"
"github.com/go-kit/kit/endpoint"
)
func TestClientRegistration(t *testing.T) {
c := newTestClient(nil)
services, _, err := c.Service(testRegistration.Name, "", true, &stdconsul.QueryOptions{})
if err != nil {
t.Error(err)
}
if want, have := 0, len(services); want != have {
t.Errorf("want %d, have %d", want, have)
}
if err := c.Register(testRegistration); err != nil {
t.Error(err)
}
if err := c.Register(testRegistration); err == nil {
t.Errorf("want error, have %v", err)
}
services, _, err = c.Service(testRegistration.Name, "", true, &stdconsul.QueryOptions{})
if err != nil {
t.Error(err)
}
if want, have := 1, len(services); want != have {
t.Errorf("want %d, have %d", want, have)
}
if err := c.Deregister(testRegistration); err != nil {
t.Error(err)
}
if err := c.Deregister(testRegistration); err == nil {
t.Errorf("want error, have %v", err)
}
services, _, err = c.Service(testRegistration.Name, "", true, &stdconsul.QueryOptions{})
if err != nil {
t.Error(err)
}
if want, have := 0, len(services); want != have {
t.Errorf("want %d, have %d", want, have)
}
}
type testClient struct {
entries []*stdconsul.ServiceEntry
}
func newTestClient(entries []*stdconsul.ServiceEntry) *testClient {
return &testClient{
entries: entries,
}
}
var _ Client = &testClient{}
func (c *testClient) Service(service, tag string, _ bool, opts *stdconsul.QueryOptions) ([]*stdconsul.ServiceEntry, *stdconsul.QueryMeta, error) {
var results []*stdconsul.ServiceEntry
for _, entry := range c.entries {
if entry.Service.Service != service {
continue
}
if tag != "" {
tagMap := map[string]struct{}{}
for _, t := range entry.Service.Tags {
tagMap[t] = struct{}{}
}
if _, ok := tagMap[tag]; !ok {
continue
}
}
results = append(results, entry)
}
return results, &stdconsul.QueryMeta{}, nil
}
func (c *testClient) Register(r *stdconsul.AgentServiceRegistration) error {
toAdd := registration2entry(r)
for _, entry := range c.entries {
if reflect.DeepEqual(*entry, *toAdd) {
return errors.New("duplicate")
}
}
c.entries = append(c.entries, toAdd)
return nil
}
func (c *testClient) Deregister(r *stdconsul.AgentServiceRegistration) error {
toDelete := registration2entry(r)
var newEntries []*stdconsul.ServiceEntry
for _, entry := range c.entries {
if reflect.DeepEqual(*entry, *toDelete) {
continue
}
newEntries = append(newEntries, entry)
}
if len(newEntries) == len(c.entries) {
return errors.New("not found")
}
c.entries = newEntries
return nil
}
func registration2entry(r *stdconsul.AgentServiceRegistration) *stdconsul.ServiceEntry {
return &stdconsul.ServiceEntry{
Node: &stdconsul.Node{
Node: "some-node",
Address: r.Address,
},
Service: &stdconsul.AgentService{
ID: r.ID,
Service: r.Name,
Tags: r.Tags,
Port: r.Port,
Address: r.Address,
},
// Checks ignored
}
}
func testFactory(instance string) (endpoint.Endpoint, io.Closer, error) {
return func(context.Context, interface{}) (interface{}, error) {
return instance, nil
}, nil, nil
}
var testRegistration = &stdconsul.AgentServiceRegistration{
ID: "my-id",
Name: "my-name",
Tags: []string{"my-tag-1", "my-tag-2"},
Port: 12345,
Address: "my-address",
}

View File

@ -0,0 +1,86 @@
// +build integration
package consul
import (
"io"
"os"
"testing"
"time"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/service"
stdconsul "github.com/hashicorp/consul/api"
)
func TestIntegration(t *testing.T) {
// Connect to Consul.
// docker run -p 8500:8500 progrium/consul -server -bootstrap
consulAddr := os.Getenv("CONSUL_ADDRESS")
if consulAddr == "" {
t.Fatal("CONSUL_ADDRESS is not set")
}
stdClient, err := stdconsul.NewClient(&stdconsul.Config{
Address: consulAddr,
})
if err != nil {
t.Fatal(err)
}
client := NewClient(stdClient)
logger := log.NewLogfmtLogger(os.Stderr)
// Produce a fake service registration.
r := &stdconsul.AgentServiceRegistration{
ID: "my-service-ID",
Name: "my-service-name",
Tags: []string{"alpha", "beta"},
Port: 12345,
Address: "my-address",
EnableTagOverride: false,
// skipping check(s)
}
// Build a subscriber on r.Name + r.Tags.
factory := func(instance string) (service.Service, io.Closer, error) {
t.Logf("factory invoked for %q", instance)
return service.Fixed{}, nil, nil
}
subscriber, err := NewSubscriber(
client,
factory,
log.NewContext(logger).With("component", "subscriber"),
r.Name,
r.Tags,
true,
)
if err != nil {
t.Fatal(err)
}
time.Sleep(time.Second)
// Before we publish, we should have no services.
services, err := subscriber.Services()
if err != nil {
t.Error(err)
}
if want, have := 0, len(services); want != have {
t.Errorf("want %d, have %d", want, have)
}
// Build a registrar for r.
registrar := NewRegistrar(client, r, log.NewContext(logger).With("component", "registrar"))
registrar.Register()
defer registrar.Deregister()
time.Sleep(time.Second)
// Now we should have one active service.
services, err = subscriber.Services()
if err != nil {
t.Error(err)
}
if want, have := 1, len(services); want != have {
t.Errorf("want %d, have %d", want, have)
}
}

44
sd/consul/registrar.go Normal file
View File

@ -0,0 +1,44 @@
package consul
import (
"fmt"
stdconsul "github.com/hashicorp/consul/api"
"github.com/go-kit/kit/log"
)
// Registrar registers service instance liveness information to Consul.
type Registrar struct {
client Client
registration *stdconsul.AgentServiceRegistration
logger log.Logger
}
// NewRegistrar returns a Consul Registrar acting on the provided catalog
// registration.
func NewRegistrar(client Client, r *stdconsul.AgentServiceRegistration, logger log.Logger) *Registrar {
return &Registrar{
client: client,
registration: r,
logger: log.NewContext(logger).With("service", r.Name, "tags", fmt.Sprint(r.Tags), "address", r.Address),
}
}
// Register implements sd.Registrar interface.
func (p *Registrar) Register() {
if err := p.client.Register(p.registration); err != nil {
p.logger.Log("err", err)
} else {
p.logger.Log("action", "register")
}
}
// Deregister implements sd.Registrar interface.
func (p *Registrar) Deregister() {
if err := p.client.Deregister(p.registration); err != nil {
p.logger.Log("err", err)
} else {
p.logger.Log("action", "deregister")
}
}

View File

@ -0,0 +1,27 @@
package consul
import (
"testing"
stdconsul "github.com/hashicorp/consul/api"
"github.com/go-kit/kit/log"
)
func TestRegistrar(t *testing.T) {
client := newTestClient([]*stdconsul.ServiceEntry{})
p := NewRegistrar(client, testRegistration, log.NewNopLogger())
if want, have := 0, len(client.entries); want != have {
t.Errorf("want %d, have %d", want, have)
}
p.Register()
if want, have := 1, len(client.entries); want != have {
t.Errorf("want %d, have %d", want, have)
}
p.Deregister()
if want, have := 0, len(client.entries); want != have {
t.Errorf("want %d, have %d", want, have)
}
}

166
sd/consul/subscriber.go Normal file
View File

@ -0,0 +1,166 @@
package consul
import (
"fmt"
"io"
consul "github.com/hashicorp/consul/api"
"github.com/go-kit/kit/endpoint"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/sd"
"github.com/go-kit/kit/sd/cache"
)
const defaultIndex = 0
// Subscriber yields endpoints for a service in Consul. Updates to the service
// are watched and will update the Subscriber endpoints.
type Subscriber struct {
cache *cache.Cache
client Client
logger log.Logger
service string
tags []string
passingOnly bool
endpointsc chan []endpoint.Endpoint
quitc chan struct{}
}
var _ sd.Subscriber = &Subscriber{}
// NewSubscriber returns a Consul subscriber which returns endpoints for the
// requested service. It only returns instances for which all of the passed tags
// are present.
func NewSubscriber(client Client, factory sd.Factory, logger log.Logger, service string, tags []string, passingOnly bool) (*Subscriber, error) {
s := &Subscriber{
cache: cache.New(factory, logger),
client: client,
logger: log.NewContext(logger).With("service", service, "tags", fmt.Sprint(tags)),
service: service,
tags: tags,
passingOnly: passingOnly,
quitc: make(chan struct{}),
}
instances, index, err := s.getInstances(defaultIndex, nil)
if err == nil {
s.logger.Log("instances", len(instances))
} else {
s.logger.Log("err", err)
}
s.cache.Update(instances)
go s.loop(index)
return s, nil
}
// Endpoints implements the Subscriber interface.
func (s *Subscriber) Endpoints() ([]endpoint.Endpoint, error) {
return s.cache.Endpoints(), nil
}
// Stop terminates the subscriber.
func (s *Subscriber) Stop() {
close(s.quitc)
}
func (s *Subscriber) loop(lastIndex uint64) {
var (
instances []string
err error
)
for {
instances, lastIndex, err = s.getInstances(lastIndex, s.quitc)
switch {
case err == io.EOF:
return // stopped via quitc
case err != nil:
s.logger.Log("err", err)
default:
s.cache.Update(instances)
}
}
}
func (s *Subscriber) getInstances(lastIndex uint64, interruptc chan struct{}) ([]string, uint64, error) {
tag := ""
if len(s.tags) > 0 {
tag = s.tags[0]
}
// Consul doesn't support more than one tag in its service query method.
// https://github.com/hashicorp/consul/issues/294
// Hashi suggest prepared queries, but they don't support blocking.
// https://www.consul.io/docs/agent/http/query.html#execute
// If we want blocking for efficiency, we must filter tags manually.
type response struct {
instances []string
index uint64
}
var (
errc = make(chan error, 1)
resc = make(chan response, 1)
)
go func() {
entries, meta, err := s.client.Service(s.service, tag, s.passingOnly, &consul.QueryOptions{
WaitIndex: lastIndex,
})
if err != nil {
errc <- err
return
}
if len(s.tags) > 1 {
entries = filterEntries(entries, s.tags[1:]...)
}
resc <- response{
instances: makeInstances(entries),
index: meta.LastIndex,
}
}()
select {
case err := <-errc:
return nil, 0, err
case res := <-resc:
return res.instances, res.index, nil
case <-interruptc:
return nil, 0, io.EOF
}
}
func filterEntries(entries []*consul.ServiceEntry, tags ...string) []*consul.ServiceEntry {
var es []*consul.ServiceEntry
ENTRIES:
for _, entry := range entries {
ts := make(map[string]struct{}, len(entry.Service.Tags))
for _, tag := range entry.Service.Tags {
ts[tag] = struct{}{}
}
for _, tag := range tags {
if _, ok := ts[tag]; !ok {
continue ENTRIES
}
}
es = append(es, entry)
}
return es
}
func makeInstances(entries []*consul.ServiceEntry) []string {
instances := make([]string, len(entries))
for i, entry := range entries {
addr := entry.Node.Address
if entry.Service.Address != "" {
addr = entry.Service.Address
}
instances[i] = fmt.Sprintf("%s:%d", addr, entry.Service.Port)
}
return instances
}

View File

@ -0,0 +1,150 @@
package consul
import (
"testing"
consul "github.com/hashicorp/consul/api"
"golang.org/x/net/context"
"github.com/go-kit/kit/log"
)
var consulState = []*consul.ServiceEntry{
{
Node: &consul.Node{
Address: "10.0.0.0",
Node: "app00.local",
},
Service: &consul.AgentService{
ID: "search-api-0",
Port: 8000,
Service: "search",
Tags: []string{
"api",
"v1",
},
},
},
{
Node: &consul.Node{
Address: "10.0.0.1",
Node: "app01.local",
},
Service: &consul.AgentService{
ID: "search-api-1",
Port: 8001,
Service: "search",
Tags: []string{
"api",
"v2",
},
},
},
{
Node: &consul.Node{
Address: "10.0.0.1",
Node: "app01.local",
},
Service: &consul.AgentService{
Address: "10.0.0.10",
ID: "search-db-0",
Port: 9000,
Service: "search",
Tags: []string{
"db",
},
},
},
}
func TestSubscriber(t *testing.T) {
var (
logger = log.NewNopLogger()
client = newTestClient(consulState)
)
s, err := NewSubscriber(client, testFactory, logger, "search", []string{"api"}, true)
if err != nil {
t.Fatal(err)
}
defer s.Stop()
endpoints, err := s.Endpoints()
if err != nil {
t.Fatal(err)
}
if want, have := 2, len(endpoints); want != have {
t.Errorf("want %d, have %d", want, have)
}
}
func TestSubscriberNoService(t *testing.T) {
var (
logger = log.NewNopLogger()
client = newTestClient(consulState)
)
s, err := NewSubscriber(client, testFactory, logger, "feed", []string{}, true)
if err != nil {
t.Fatal(err)
}
defer s.Stop()
endpoints, err := s.Endpoints()
if err != nil {
t.Fatal(err)
}
if want, have := 0, len(endpoints); want != have {
t.Fatalf("want %d, have %d", want, have)
}
}
func TestSubscriberWithTags(t *testing.T) {
var (
logger = log.NewNopLogger()
client = newTestClient(consulState)
)
s, err := NewSubscriber(client, testFactory, logger, "search", []string{"api", "v2"}, true)
if err != nil {
t.Fatal(err)
}
defer s.Stop()
endpoints, err := s.Endpoints()
if err != nil {
t.Fatal(err)
}
if want, have := 1, len(endpoints); want != have {
t.Fatalf("want %d, have %d", want, have)
}
}
func TestSubscriberAddressOverride(t *testing.T) {
s, err := NewSubscriber(newTestClient(consulState), testFactory, log.NewNopLogger(), "search", []string{"db"}, true)
if err != nil {
t.Fatal(err)
}
defer s.Stop()
endpoints, err := s.Endpoints()
if err != nil {
t.Fatal(err)
}
if want, have := 1, len(endpoints); want != have {
t.Fatalf("want %d, have %d", want, have)
}
response, err := endpoints[0](context.Background(), struct{}{})
if err != nil {
t.Fatal(err)
}
if want, have := "10.0.0.10:9000", response.(string); want != have {
t.Errorf("want %q, have %q", want, have)
}
}

7
sd/dnssrv/lookup.go Normal file
View File

@ -0,0 +1,7 @@
package dnssrv
import "net"
// Lookup is a function that resolves a DNS SRV record to multiple addresses.
// It has the same signature as net.LookupSRV.
type Lookup func(service, proto, name string) (cname string, addrs []*net.SRV, err error)

100
sd/dnssrv/subscriber.go Normal file
View File

@ -0,0 +1,100 @@
package dnssrv
import (
"fmt"
"net"
"time"
"github.com/go-kit/kit/endpoint"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/sd"
"github.com/go-kit/kit/sd/cache"
)
// Subscriber yields endpoints taken from the named DNS SRV record. The name is
// resolved on a fixed schedule. Priorities and weights are ignored.
type Subscriber struct {
name string
cache *cache.Cache
logger log.Logger
quit chan struct{}
}
// NewSubscriber returns a DNS SRV subscriber.
func NewSubscriber(
name string,
ttl time.Duration,
factory sd.Factory,
logger log.Logger,
) *Subscriber {
return NewSubscriberDetailed(name, time.NewTicker(ttl), net.LookupSRV, factory, logger)
}
// NewSubscriberDetailed is the same as NewSubscriber, but allows users to
// provide an explicit lookup refresh ticker instead of a TTL, and specify the
// lookup function instead of using net.LookupSRV.
func NewSubscriberDetailed(
name string,
refresh *time.Ticker,
lookup Lookup,
factory sd.Factory,
logger log.Logger,
) *Subscriber {
p := &Subscriber{
name: name,
cache: cache.New(factory, logger),
logger: logger,
quit: make(chan struct{}),
}
instances, err := p.resolve(lookup)
if err == nil {
logger.Log("name", name, "instances", len(instances))
} else {
logger.Log("name", name, "err", err)
}
p.cache.Update(instances)
go p.loop(refresh, lookup)
return p
}
// Stop terminates the Subscriber.
func (p *Subscriber) Stop() {
close(p.quit)
}
func (p *Subscriber) loop(t *time.Ticker, lookup Lookup) {
defer t.Stop()
for {
select {
case <-t.C:
instances, err := p.resolve(lookup)
if err != nil {
p.logger.Log("name", p.name, "err", err)
continue // don't replace potentially-good with bad
}
p.cache.Update(instances)
case <-p.quit:
return
}
}
}
// Endpoints implements the Subscriber interface.
func (p *Subscriber) Endpoints() ([]endpoint.Endpoint, error) {
return p.cache.Endpoints(), nil
}
func (p *Subscriber) resolve(lookup Lookup) ([]string, error) {
_, addrs, err := lookup("", "", p.name)
if err != nil {
return []string{}, err
}
instances := make([]string, len(addrs))
for i, addr := range addrs {
instances[i] = net.JoinHostPort(addr.Target, fmt.Sprint(addr.Port))
}
return instances, nil
}

View File

@ -0,0 +1,85 @@
package dnssrv
import (
"io"
"net"
"sync/atomic"
"testing"
"time"
"github.com/go-kit/kit/endpoint"
"github.com/go-kit/kit/log"
)
func TestRefresh(t *testing.T) {
name := "some.service.internal"
ticker := time.NewTicker(time.Second)
ticker.Stop()
tickc := make(chan time.Time)
ticker.C = tickc
var lookups uint64
records := []*net.SRV{}
lookup := func(service, proto, name string) (string, []*net.SRV, error) {
t.Logf("lookup(%q, %q, %q)", service, proto, name)
atomic.AddUint64(&lookups, 1)
return "cname", records, nil
}
var generates uint64
factory := func(instance string) (endpoint.Endpoint, io.Closer, error) {
t.Logf("factory(%q)", instance)
atomic.AddUint64(&generates, 1)
return endpoint.Nop, nopCloser{}, nil
}
subscriber := NewSubscriberDetailed(name, ticker, lookup, factory, log.NewNopLogger())
defer subscriber.Stop()
// First lookup, empty
endpoints, err := subscriber.Endpoints()
if err != nil {
t.Error(err)
}
if want, have := 0, len(endpoints); want != have {
t.Errorf("want %d, have %d", want, have)
}
if want, have := uint64(1), atomic.LoadUint64(&lookups); want != have {
t.Errorf("want %d, have %d", want, have)
}
if want, have := uint64(0), atomic.LoadUint64(&generates); want != have {
t.Errorf("want %d, have %d", want, have)
}
// Load some records and lookup again
records = []*net.SRV{
&net.SRV{Target: "1.0.0.1", Port: 1001},
&net.SRV{Target: "1.0.0.2", Port: 1002},
&net.SRV{Target: "1.0.0.3", Port: 1003},
}
tickc <- time.Now()
// There is a race condition where the subscriber.Endpoints call below
// invokes the cache before it is updated by the tick above.
// TODO(pb): solve by running the read through the loop goroutine.
time.Sleep(100 * time.Millisecond)
endpoints, err = subscriber.Endpoints()
if err != nil {
t.Error(err)
}
if want, have := 3, len(endpoints); want != have {
t.Errorf("want %d, have %d", want, have)
}
if want, have := uint64(2), atomic.LoadUint64(&lookups); want != have {
t.Errorf("want %d, have %d", want, have)
}
if want, have := uint64(len(records)), atomic.LoadUint64(&generates); want != have {
t.Errorf("want %d, have %d", want, have)
}
}
type nopCloser struct{}
func (nopCloser) Close() error { return nil }

5
sd/doc.go Normal file
View File

@ -0,0 +1,5 @@
// Package sd provides utilities related to service discovery. That includes
// subscribing to service discovery systems in order to reach remote instances,
// and publishing to service discovery systems to make an instance available.
// Implementations are provided for most common systems.
package sd

131
sd/etcd/client.go Normal file
View File

@ -0,0 +1,131 @@
package etcd
import (
"crypto/tls"
"crypto/x509"
"io/ioutil"
"net"
"net/http"
"time"
etcd "github.com/coreos/etcd/client"
"golang.org/x/net/context"
)
// Client is a wrapper around the etcd client.
type Client interface {
// GetEntries will query the given prefix in etcd and returns a set of entries.
GetEntries(prefix string) ([]string, error)
// WatchPrefix starts watching every change for given prefix in etcd. When an
// change is detected it will populate the responseChan when an *etcd.Response.
WatchPrefix(prefix string, responseChan chan *etcd.Response)
}
type client struct {
keysAPI etcd.KeysAPI
ctx context.Context
}
// ClientOptions defines options for the etcd client.
type ClientOptions struct {
Cert string
Key string
CaCert string
DialTimeout time.Duration
DialKeepAline time.Duration
HeaderTimeoutPerRequest time.Duration
}
// NewClient returns an *etcd.Client with a connection to the named machines.
// It will return an error if a connection to the cluster cannot be made.
// The parameter machines needs to be a full URL with schemas.
// e.g. "http://localhost:2379" will work, but "localhost:2379" will not.
func NewClient(ctx context.Context, machines []string, options ClientOptions) (Client, error) {
var (
c etcd.KeysAPI
err error
caCertCt []byte
tlsCert tls.Certificate
)
if options.Cert != "" && options.Key != "" {
tlsCert, err = tls.LoadX509KeyPair(options.Cert, options.Key)
if err != nil {
return nil, err
}
caCertCt, err = ioutil.ReadFile(options.CaCert)
if err != nil {
return nil, err
}
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCertCt)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{tlsCert},
RootCAs: caCertPool,
}
transport := &http.Transport{
TLSClientConfig: tlsConfig,
Dial: func(network, addr string) (net.Conn, error) {
dial := &net.Dialer{
Timeout: options.DialTimeout,
KeepAlive: options.DialKeepAline,
}
return dial.Dial(network, addr)
},
}
cfg := etcd.Config{
Endpoints: machines,
Transport: transport,
HeaderTimeoutPerRequest: options.HeaderTimeoutPerRequest,
}
ce, err := etcd.New(cfg)
if err != nil {
return nil, err
}
c = etcd.NewKeysAPI(ce)
} else {
cfg := etcd.Config{
Endpoints: machines,
Transport: etcd.DefaultTransport,
HeaderTimeoutPerRequest: options.HeaderTimeoutPerRequest,
}
ce, err := etcd.New(cfg)
if err != nil {
return nil, err
}
c = etcd.NewKeysAPI(ce)
}
return &client{c, ctx}, nil
}
// GetEntries implements the etcd Client interface.
func (c *client) GetEntries(key string) ([]string, error) {
resp, err := c.keysAPI.Get(c.ctx, key, &etcd.GetOptions{Recursive: true})
if err != nil {
return nil, err
}
entries := make([]string, len(resp.Node.Nodes))
for i, node := range resp.Node.Nodes {
entries[i] = node.Value
}
return entries, nil
}
// WatchPrefix implements the etcd Client interface.
func (c *client) WatchPrefix(prefix string, responseChan chan *etcd.Response) {
watch := c.keysAPI.Watcher(prefix, &etcd.WatcherOptions{AfterIndex: 0, Recursive: true})
for {
res, err := watch.Next(c.ctx)
if err != nil {
return
}
responseChan <- res
}
}

74
sd/etcd/subscriber.go Normal file
View File

@ -0,0 +1,74 @@
package etcd
import (
etcd "github.com/coreos/etcd/client"
"github.com/go-kit/kit/endpoint"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/sd"
"github.com/go-kit/kit/sd/cache"
)
// Subscriber yield endpoints stored in a certain etcd keyspace. Any kind of
// change in that keyspace is watched and will update the Subscriber endpoints.
type Subscriber struct {
client Client
prefix string
cache *cache.Cache
logger log.Logger
quitc chan struct{}
}
var _ sd.Subscriber = &Subscriber{}
// NewSubscriber returns an etcd subscriber. It will start watching the given
// prefix for changes, and update the endpoints.
func NewSubscriber(c Client, prefix string, factory sd.Factory, logger log.Logger) (*Subscriber, error) {
s := &Subscriber{
client: c,
prefix: prefix,
cache: cache.New(factory, logger),
logger: logger,
quitc: make(chan struct{}),
}
instances, err := s.client.GetEntries(s.prefix)
if err == nil {
logger.Log("prefix", s.prefix, "instances", len(instances))
} else {
logger.Log("prefix", s.prefix, "err", err)
}
s.cache.Update(instances)
go s.loop()
return s, nil
}
func (s *Subscriber) loop() {
responseChan := make(chan *etcd.Response)
go s.client.WatchPrefix(s.prefix, responseChan)
for {
select {
case <-responseChan:
instances, err := s.client.GetEntries(s.prefix)
if err != nil {
s.logger.Log("msg", "failed to retrieve entries", "err", err)
continue
}
s.cache.Update(instances)
case <-s.quitc:
return
}
}
}
// Endpoints implements the Subscriber interface.
func (s *Subscriber) Endpoints() ([]endpoint.Endpoint, error) {
return s.cache.Endpoints(), nil
}
// Stop terminates the Subscriber.
func (s *Subscriber) Stop() {
close(s.quitc)
}

View File

@ -0,0 +1,89 @@
package etcd
import (
"errors"
"io"
"testing"
stdetcd "github.com/coreos/etcd/client"
"github.com/go-kit/kit/endpoint"
"github.com/go-kit/kit/log"
)
var (
node = &stdetcd.Node{
Key: "/foo",
Nodes: []*stdetcd.Node{
{Key: "/foo/1", Value: "1:1"},
{Key: "/foo/2", Value: "1:2"},
},
}
fakeResponse = &stdetcd.Response{
Node: node,
}
)
func TestSubscriber(t *testing.T) {
factory := func(string) (endpoint.Endpoint, io.Closer, error) {
return endpoint.Nop, nil, nil
}
client := &fakeClient{
responses: map[string]*stdetcd.Response{"/foo": fakeResponse},
}
s, err := NewSubscriber(client, "/foo", factory, log.NewNopLogger())
if err != nil {
t.Fatal(err)
}
defer s.Stop()
if _, err := s.Endpoints(); err != nil {
t.Fatal(err)
}
}
func TestBadFactory(t *testing.T) {
factory := func(string) (endpoint.Endpoint, io.Closer, error) {
return nil, nil, errors.New("kaboom")
}
client := &fakeClient{
responses: map[string]*stdetcd.Response{"/foo": fakeResponse},
}
s, err := NewSubscriber(client, "/foo", factory, log.NewNopLogger())
if err != nil {
t.Fatal(err)
}
defer s.Stop()
endpoints, err := s.Endpoints()
if err != nil {
t.Fatal(err)
}
if want, have := 0, len(endpoints); want != have {
t.Errorf("want %d, have %d", want, have)
}
}
type fakeClient struct {
responses map[string]*stdetcd.Response
}
func (c *fakeClient) GetEntries(prefix string) ([]string, error) {
response, ok := c.responses[prefix]
if !ok {
return nil, errors.New("key not exist")
}
entries := make([]string, len(response.Node.Nodes))
for i, node := range response.Node.Nodes {
entries[i] = node.Value
}
return entries, nil
}
func (c *fakeClient) WatchPrefix(prefix string, responseChan chan *stdetcd.Response) {}

16
sd/factory.go Normal file
View File

@ -0,0 +1,16 @@
package sd
import (
"io"
"github.com/go-kit/kit/endpoint"
)
// Factory is a function that converts an instance string (e.g. host:port) to a
// specific endpoint. Instances that provide multiple endpoints require multiple
// factories. A factory also returns an io.Closer that's invoked when the
// instance goes away and needs to be cleaned up.
//
// Users are expected to provide their own factory functions that assume
// specific transports, or can deduce transports by parsing the instance string.
type Factory func(instance string) (endpoint.Endpoint, io.Closer, error)

9
sd/fixed_subscriber.go Normal file
View File

@ -0,0 +1,9 @@
package sd
import "github.com/go-kit/kit/endpoint"
// FixedSubscriber yields a fixed set of services.
type FixedSubscriber []endpoint.Endpoint
// Endpoints implements Subscriber.
func (s FixedSubscriber) Endpoints() ([]endpoint.Endpoint, error) { return s, nil }

15
sd/lb/balancer.go Normal file
View File

@ -0,0 +1,15 @@
package lb
import (
"errors"
"github.com/go-kit/kit/endpoint"
)
// Balancer yields endpoints according to some heuristic.
type Balancer interface {
Endpoint() (endpoint.Endpoint, error)
}
// ErrNoEndpoints is returned when no qualifying endpoints are available.
var ErrNoEndpoints = errors.New("no endpoints available")

5
sd/lb/doc.go Normal file
View File

@ -0,0 +1,5 @@
// Package lb deals with client-side load balancing across multiple identical
// instances of services and endpoints. When combined with a service discovery
// system of record, it enables a more decentralized architecture, removing the
// need for separate load balancers like HAProxy.
package lb

32
sd/lb/random.go Normal file
View File

@ -0,0 +1,32 @@
package lb
import (
"math/rand"
"github.com/go-kit/kit/endpoint"
"github.com/go-kit/kit/sd"
)
// NewRandom returns a load balancer that selects services randomly.
func NewRandom(s sd.Subscriber, seed int64) Balancer {
return &random{
s: s,
r: rand.New(rand.NewSource(seed)),
}
}
type random struct {
s sd.Subscriber
r *rand.Rand
}
func (r *random) Endpoint() (endpoint.Endpoint, error) {
endpoints, err := r.s.Endpoints()
if err != nil {
return nil, err
}
if len(endpoints) <= 0 {
return nil, ErrNoEndpoints
}
return endpoints[r.r.Intn(len(endpoints))], nil
}

52
sd/lb/random_test.go Normal file
View File

@ -0,0 +1,52 @@
package lb
import (
"math"
"testing"
"github.com/go-kit/kit/endpoint"
"github.com/go-kit/kit/sd"
"golang.org/x/net/context"
)
func TestRandom(t *testing.T) {
var (
n = 7
endpoints = make([]endpoint.Endpoint, n)
counts = make([]int, n)
seed = int64(12345)
iterations = 1000000
want = iterations / n
tolerance = want / 100 // 1%
)
for i := 0; i < n; i++ {
i0 := i
endpoints[i] = func(context.Context, interface{}) (interface{}, error) { counts[i0]++; return struct{}{}, nil }
}
subscriber := sd.FixedSubscriber(endpoints)
balancer := NewRandom(subscriber, seed)
for i := 0; i < iterations; i++ {
endpoint, _ := balancer.Endpoint()
endpoint(context.Background(), struct{}{})
}
for i, have := range counts {
delta := int(math.Abs(float64(want - have)))
if delta > tolerance {
t.Errorf("%d: want %d, have %d, delta %d > %d tolerance", i, want, have, delta, tolerance)
}
}
}
func TestRandomNoEndpoints(t *testing.T) {
subscriber := sd.FixedSubscriber{}
balancer := NewRandom(subscriber, 1415926)
_, err := balancer.Endpoint()
if want, have := ErrNoEndpoints, err; want != have {
t.Errorf("want %v, have %v", want, have)
}
}

57
sd/lb/retry.go Normal file
View File

@ -0,0 +1,57 @@
package lb
import (
"fmt"
"strings"
"time"
"golang.org/x/net/context"
"github.com/go-kit/kit/endpoint"
)
// Retry wraps a service load balancer and returns an endpoint oriented load
// balancer for the specified service method.
// Requests to the endpoint will be automatically load balanced via the load
// balancer. Requests that return errors will be retried until they succeed,
// up to max times, or until the timeout is elapsed, whichever comes first.
func Retry(max int, timeout time.Duration, b Balancer) endpoint.Endpoint {
if b == nil {
panic("nil Balancer")
}
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
var (
newctx, cancel = context.WithTimeout(ctx, timeout)
responses = make(chan interface{}, 1)
errs = make(chan error, 1)
a = []string{}
)
defer cancel()
for i := 1; i <= max; i++ {
go func() {
e, err := b.Endpoint()
if err != nil {
errs <- err
return
}
response, err := e(newctx, request)
if err != nil {
errs <- err
return
}
responses <- response
}()
select {
case <-newctx.Done():
return nil, newctx.Err()
case response := <-responses:
return response, nil
case err := <-errs:
a = append(a, err.Error())
continue
}
}
return nil, fmt.Errorf("retry attempts exceeded (%s)", strings.Join(a, "; "))
}
}

90
sd/lb/retry_test.go Normal file
View File

@ -0,0 +1,90 @@
package lb_test
import (
"errors"
"testing"
"time"
"golang.org/x/net/context"
"github.com/go-kit/kit/endpoint"
"github.com/go-kit/kit/sd"
loadbalancer "github.com/go-kit/kit/sd/lb"
)
func TestRetryMaxTotalFail(t *testing.T) {
var (
endpoints = sd.FixedSubscriber{} // no endpoints
lb = loadbalancer.NewRoundRobin(endpoints)
retry = loadbalancer.Retry(999, time.Second, lb) // lots of retries
ctx = context.Background()
)
if _, err := retry(ctx, struct{}{}); err == nil {
t.Errorf("expected error, got none") // should fail
}
}
func TestRetryMaxPartialFail(t *testing.T) {
var (
endpoints = []endpoint.Endpoint{
func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") },
func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") },
func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ },
}
subscriber = sd.FixedSubscriber{
0: endpoints[0],
1: endpoints[1],
2: endpoints[2],
}
retries = len(endpoints) - 1 // not quite enough retries
lb = loadbalancer.NewRoundRobin(subscriber)
ctx = context.Background()
)
if _, err := loadbalancer.Retry(retries, time.Second, lb)(ctx, struct{}{}); err == nil {
t.Errorf("expected error, got none")
}
}
func TestRetryMaxSuccess(t *testing.T) {
var (
endpoints = []endpoint.Endpoint{
func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") },
func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") },
func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ },
}
subscriber = sd.FixedSubscriber{
0: endpoints[0],
1: endpoints[1],
2: endpoints[2],
}
retries = len(endpoints) // exactly enough retries
lb = loadbalancer.NewRoundRobin(subscriber)
ctx = context.Background()
)
if _, err := loadbalancer.Retry(retries, time.Second, lb)(ctx, struct{}{}); err != nil {
t.Error(err)
}
}
func TestRetryTimeout(t *testing.T) {
var (
step = make(chan struct{})
e = func(context.Context, interface{}) (interface{}, error) { <-step; return struct{}{}, nil }
timeout = time.Millisecond
retry = loadbalancer.Retry(999, timeout, loadbalancer.NewRoundRobin(sd.FixedSubscriber{0: e}))
errs = make(chan error, 1)
invoke = func() { _, err := retry(context.Background(), struct{}{}); errs <- err }
)
go func() { step <- struct{}{} }() // queue up a flush of the endpoint
invoke() // invoke the endpoint and trigger the flush
if err := <-errs; err != nil { // that should succeed
t.Error(err)
}
go func() { time.Sleep(10 * timeout); step <- struct{}{} }() // a delayed flush
invoke() // invoke the endpoint
if err := <-errs; err != context.DeadlineExceeded { // that should not succeed
t.Errorf("wanted %v, got none", context.DeadlineExceeded)
}
}

34
sd/lb/round_robin.go Normal file
View File

@ -0,0 +1,34 @@
package lb
import (
"sync/atomic"
"github.com/go-kit/kit/endpoint"
"github.com/go-kit/kit/sd"
)
// NewRoundRobin returns a load balancer that returns services in sequence.
func NewRoundRobin(s sd.Subscriber) Balancer {
return &roundRobin{
s: s,
c: 0,
}
}
type roundRobin struct {
s sd.Subscriber
c uint64
}
func (rr *roundRobin) Endpoint() (endpoint.Endpoint, error) {
endpoints, err := rr.s.Endpoints()
if err != nil {
return nil, err
}
if len(endpoints) <= 0 {
return nil, ErrNoEndpoints
}
old := atomic.AddUint64(&rr.c, 1) - 1
idx := old % uint64(len(endpoints))
return endpoints[idx], nil
}

96
sd/lb/round_robin_test.go Normal file
View File

@ -0,0 +1,96 @@
package lb
import (
"reflect"
"sync"
"sync/atomic"
"testing"
"time"
"golang.org/x/net/context"
"github.com/go-kit/kit/endpoint"
"github.com/go-kit/kit/sd"
)
func TestRoundRobin(t *testing.T) {
var (
counts = []int{0, 0, 0}
endpoints = []endpoint.Endpoint{
func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil },
func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil },
func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil },
}
)
subscriber := sd.FixedSubscriber(endpoints)
balancer := NewRoundRobin(subscriber)
for i, want := range [][]int{
{1, 0, 0},
{1, 1, 0},
{1, 1, 1},
{2, 1, 1},
{2, 2, 1},
{2, 2, 2},
{3, 2, 2},
} {
endpoint, err := balancer.Endpoint()
if err != nil {
t.Fatal(err)
}
endpoint(context.Background(), struct{}{})
if have := counts; !reflect.DeepEqual(want, have) {
t.Fatalf("%d: want %v, have %v", i, want, have)
}
}
}
func TestRoundRobinNoEndpoints(t *testing.T) {
subscriber := sd.FixedSubscriber{}
balancer := NewRoundRobin(subscriber)
_, err := balancer.Endpoint()
if want, have := ErrNoEndpoints, err; want != have {
t.Errorf("want %v, have %v", want, have)
}
}
func TestRoundRobinNoRace(t *testing.T) {
balancer := NewRoundRobin(sd.FixedSubscriber([]endpoint.Endpoint{
endpoint.Nop,
endpoint.Nop,
endpoint.Nop,
endpoint.Nop,
endpoint.Nop,
}))
var (
n = 100
done = make(chan struct{})
wg sync.WaitGroup
count uint64
)
wg.Add(n)
for i := 0; i < n; i++ {
go func() {
defer wg.Done()
for {
select {
case <-done:
return
default:
_, _ = balancer.Endpoint()
atomic.AddUint64(&count, 1)
}
}
}()
}
time.Sleep(time.Second)
close(done)
wg.Wait()
t.Logf("made %d calls", atomic.LoadUint64(&count))
}

13
sd/registrar.go Normal file
View File

@ -0,0 +1,13 @@
package sd
// Registrar registers instance information to a service discovery system when
// an instance becomes alive and healthy, and deregisters that information when
// the service becomes unhealthy or goes away.
//
// Registrar implementations exist for various service discovery systems. Note
// that identifying instance information (e.g. host:port) must be given via the
// concrete constructor; this interface merely signals lifecycle changes.
type Registrar interface {
Register()
Deregister()
}

11
sd/subscriber.go Normal file
View File

@ -0,0 +1,11 @@
package sd
import "github.com/go-kit/kit/endpoint"
// Subscriber listens to a service discovery system and yields a set of
// identical endpoints on demand. An error indicates a problem with connectivity
// to the service discovery system, or within the system itself; a subscriber
// may yield no endpoints without error.
type Subscriber interface {
Endpoints() ([]endpoint.Endpoint, error)
}

231
sd/zk/client.go Normal file
View File

@ -0,0 +1,231 @@
package zk
import (
"errors"
"net"
"strings"
"time"
"github.com/samuel/go-zookeeper/zk"
"github.com/go-kit/kit/log"
)
// DefaultACL is the default ACL to use for creating znodes.
var (
DefaultACL = zk.WorldACL(zk.PermAll)
ErrInvalidCredentials = errors.New("invalid credentials provided")
ErrClientClosed = errors.New("client service closed")
)
const (
// DefaultConnectTimeout is the default timeout to establish a connection to
// a ZooKeeper node.
DefaultConnectTimeout = 2 * time.Second
// DefaultSessionTimeout is the default timeout to keep the current
// ZooKeeper session alive during a temporary disconnect.
DefaultSessionTimeout = 5 * time.Second
)
// Client is a wrapper around a lower level ZooKeeper client implementation.
type Client interface {
// GetEntries should query the provided path in ZooKeeper, place a watch on
// it and retrieve data from its current child nodes.
GetEntries(path string) ([]string, <-chan zk.Event, error)
// CreateParentNodes should try to create the path in case it does not exist
// yet on ZooKeeper.
CreateParentNodes(path string) error
// Stop should properly shutdown the client implementation
Stop()
}
type clientConfig struct {
logger log.Logger
acl []zk.ACL
credentials []byte
connectTimeout time.Duration
sessionTimeout time.Duration
rootNodePayload [][]byte
eventHandler func(zk.Event)
}
// Option functions enable friendly APIs.
type Option func(*clientConfig) error
type client struct {
*zk.Conn
clientConfig
active bool
quit chan struct{}
}
// ACL returns an Option specifying a non-default ACL for creating parent nodes.
func ACL(acl []zk.ACL) Option {
return func(c *clientConfig) error {
c.acl = acl
return nil
}
}
// Credentials returns an Option specifying a user/password combination which
// the client will use to authenticate itself with.
func Credentials(user, pass string) Option {
return func(c *clientConfig) error {
if user == "" || pass == "" {
return ErrInvalidCredentials
}
c.credentials = []byte(user + ":" + pass)
return nil
}
}
// ConnectTimeout returns an Option specifying a non-default connection timeout
// when we try to establish a connection to a ZooKeeper server.
func ConnectTimeout(t time.Duration) Option {
return func(c *clientConfig) error {
if t.Seconds() < 1 {
return errors.New("invalid connect timeout (minimum value is 1 second)")
}
c.connectTimeout = t
return nil
}
}
// SessionTimeout returns an Option specifying a non-default session timeout.
func SessionTimeout(t time.Duration) Option {
return func(c *clientConfig) error {
if t.Seconds() < 1 {
return errors.New("invalid session timeout (minimum value is 1 second)")
}
c.sessionTimeout = t
return nil
}
}
// Payload returns an Option specifying non-default data values for each znode
// created by CreateParentNodes.
func Payload(payload [][]byte) Option {
return func(c *clientConfig) error {
c.rootNodePayload = payload
return nil
}
}
// EventHandler returns an Option specifying a callback function to handle
// incoming zk.Event payloads (ZooKeeper connection events).
func EventHandler(handler func(zk.Event)) Option {
return func(c *clientConfig) error {
c.eventHandler = handler
return nil
}
}
// NewClient returns a ZooKeeper client with a connection to the server cluster.
// It will return an error if the server cluster cannot be resolved.
func NewClient(servers []string, logger log.Logger, options ...Option) (Client, error) {
defaultEventHandler := func(event zk.Event) {
logger.Log("eventtype", event.Type.String(), "server", event.Server, "state", event.State.String(), "err", event.Err)
}
config := clientConfig{
acl: DefaultACL,
connectTimeout: DefaultConnectTimeout,
sessionTimeout: DefaultSessionTimeout,
eventHandler: defaultEventHandler,
logger: logger,
}
for _, option := range options {
if err := option(&config); err != nil {
return nil, err
}
}
// dialer overrides the default ZooKeeper library Dialer so we can configure
// the connectTimeout. The current library has a hardcoded value of 1 second
// and there are reports of race conditions, due to slow DNS resolvers and
// other network latency issues.
dialer := func(network, address string, _ time.Duration) (net.Conn, error) {
return net.DialTimeout(network, address, config.connectTimeout)
}
conn, eventc, err := zk.Connect(servers, config.sessionTimeout, withLogger(logger), zk.WithDialer(dialer))
if err != nil {
return nil, err
}
if len(config.credentials) > 0 {
err = conn.AddAuth("digest", config.credentials)
if err != nil {
return nil, err
}
}
c := &client{conn, config, true, make(chan struct{})}
// Start listening for incoming Event payloads and callback the set
// eventHandler.
go func() {
for {
select {
case event := <-eventc:
config.eventHandler(event)
case <-c.quit:
return
}
}
}()
return c, nil
}
// CreateParentNodes implements the ZooKeeper Client interface.
func (c *client) CreateParentNodes(path string) error {
if !c.active {
return ErrClientClosed
}
if path[0] != '/' {
return zk.ErrInvalidPath
}
payload := []byte("")
pathString := ""
pathNodes := strings.Split(path, "/")
for i := 1; i < len(pathNodes); i++ {
if i <= len(c.rootNodePayload) {
payload = c.rootNodePayload[i-1]
} else {
payload = []byte("")
}
pathString += "/" + pathNodes[i]
_, err := c.Create(pathString, payload, 0, c.acl)
// not being able to create the node because it exists or not having
// sufficient rights is not an issue. It is ok for the node to already
// exist and/or us to only have read rights
if err != nil && err != zk.ErrNodeExists && err != zk.ErrNoAuth {
return err
}
}
return nil
}
// GetEntries implements the ZooKeeper Client interface.
func (c *client) GetEntries(path string) ([]string, <-chan zk.Event, error) {
// retrieve list of child nodes for given path and add watch to path
znodes, _, eventc, err := c.ChildrenW(path)
if err != nil {
return nil, eventc, err
}
var resp []string
for _, znode := range znodes {
// retrieve payload for child znode and add to response array
if data, _, err := c.Get(path + "/" + znode); err == nil {
resp = append(resp, string(data))
}
}
return resp, eventc, nil
}
// Stop implements the ZooKeeper Client interface.
func (c *client) Stop() {
c.active = false
close(c.quit)
c.Close()
}

157
sd/zk/client_test.go Normal file
View File

@ -0,0 +1,157 @@
package zk
import (
"bytes"
"testing"
"time"
stdzk "github.com/samuel/go-zookeeper/zk"
"github.com/go-kit/kit/log"
)
func TestNewClient(t *testing.T) {
var (
acl = stdzk.WorldACL(stdzk.PermRead)
connectTimeout = 3 * time.Second
sessionTimeout = 20 * time.Second
payload = [][]byte{[]byte("Payload"), []byte("Test")}
)
c, err := NewClient(
[]string{"FailThisInvalidHost!!!"},
log.NewNopLogger(),
)
if err == nil {
t.Errorf("expected error, got nil")
}
hasFired := false
calledEventHandler := make(chan struct{})
eventHandler := func(event stdzk.Event) {
if !hasFired {
// test is successful if this function has fired at least once
hasFired = true
close(calledEventHandler)
}
}
c, err = NewClient(
[]string{"localhost"},
log.NewNopLogger(),
ACL(acl),
ConnectTimeout(connectTimeout),
SessionTimeout(sessionTimeout),
Payload(payload),
EventHandler(eventHandler),
)
if err != nil {
t.Fatal(err)
}
defer c.Stop()
clientImpl, ok := c.(*client)
if !ok {
t.Fatal("retrieved incorrect Client implementation")
}
if want, have := acl, clientImpl.acl; want[0] != have[0] {
t.Errorf("want %+v, have %+v", want, have)
}
if want, have := connectTimeout, clientImpl.connectTimeout; want != have {
t.Errorf("want %d, have %d", want, have)
}
if want, have := sessionTimeout, clientImpl.sessionTimeout; want != have {
t.Errorf("want %d, have %d", want, have)
}
if want, have := payload, clientImpl.rootNodePayload; bytes.Compare(want[0], have[0]) != 0 || bytes.Compare(want[1], have[1]) != 0 {
t.Errorf("want %s, have %s", want, have)
}
select {
case <-calledEventHandler:
case <-time.After(100 * time.Millisecond):
t.Errorf("event handler never called")
}
}
func TestOptions(t *testing.T) {
_, err := NewClient([]string{"localhost"}, log.NewNopLogger(), Credentials("valid", "credentials"))
if err != nil && err != stdzk.ErrNoServer {
t.Errorf("unexpected error: %v", err)
}
_, err = NewClient([]string{"localhost"}, log.NewNopLogger(), Credentials("nopass", ""))
if want, have := err, ErrInvalidCredentials; want != have {
t.Errorf("want %v, have %v", want, have)
}
_, err = NewClient([]string{"localhost"}, log.NewNopLogger(), ConnectTimeout(0))
if err == nil {
t.Errorf("expected connect timeout error")
}
_, err = NewClient([]string{"localhost"}, log.NewNopLogger(), SessionTimeout(0))
if err == nil {
t.Errorf("expected connect timeout error")
}
}
func TestCreateParentNodes(t *testing.T) {
payload := [][]byte{[]byte("Payload"), []byte("Test")}
c, err := NewClient([]string{"localhost:65500"}, log.NewNopLogger())
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if c == nil {
t.Fatal("expected new Client, got nil")
}
s, err := NewSubscriber(c, "/validpath", newFactory(""), log.NewNopLogger())
if err != stdzk.ErrNoServer {
t.Errorf("unexpected error: %v", err)
}
if s != nil {
t.Error("expected failed new Subscriber")
}
s, err = NewSubscriber(c, "invalidpath", newFactory(""), log.NewNopLogger())
if err != stdzk.ErrInvalidPath {
t.Errorf("unexpected error: %v", err)
}
_, _, err = c.GetEntries("/validpath")
if err != stdzk.ErrNoServer {
t.Errorf("unexpected error: %v", err)
}
c.Stop()
err = c.CreateParentNodes("/validpath")
if err != ErrClientClosed {
t.Errorf("unexpected error: %v", err)
}
s, err = NewSubscriber(c, "/validpath", newFactory(""), log.NewNopLogger())
if err != ErrClientClosed {
t.Errorf("unexpected error: %v", err)
}
if s != nil {
t.Error("expected failed new Subscriber")
}
c, err = NewClient([]string{"localhost:65500"}, log.NewNopLogger(), Payload(payload))
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if c == nil {
t.Fatal("expected new Client, got nil")
}
s, err = NewSubscriber(c, "/validpath", newFactory(""), log.NewNopLogger())
if err != stdzk.ErrNoServer {
t.Errorf("unexpected error: %v", err)
}
if s != nil {
t.Error("expected failed new Subscriber")
}
}

201
sd/zk/integration_test.go Normal file
View File

@ -0,0 +1,201 @@
// +build integration
package zk
import (
"bytes"
"flag"
"fmt"
"os"
"testing"
"time"
stdzk "github.com/samuel/go-zookeeper/zk"
)
var (
host []string
)
func TestMain(m *testing.M) {
flag.Parse()
fmt.Println("Starting ZooKeeper server...")
ts, err := stdzk.StartTestCluster(1, nil, nil)
if err != nil {
fmt.Printf("ZooKeeper server error: %v\n", err)
os.Exit(1)
}
host = []string{fmt.Sprintf("localhost:%d", ts.Servers[0].Port)}
code := m.Run()
ts.Stop()
os.Exit(code)
}
func TestCreateParentNodesOnServer(t *testing.T) {
payload := [][]byte{[]byte("Payload"), []byte("Test")}
c1, err := NewClient(host, logger, Payload(payload))
if err != nil {
t.Fatalf("Connect returned error: %v", err)
}
if c1 == nil {
t.Fatal("Expected pointer to client, got nil")
}
defer c1.Stop()
s, err := NewSubscriber(c1, path, newFactory(""), logger)
if err != nil {
t.Fatalf("Unable to create Subscriber: %v", err)
}
defer s.Stop()
services, err := s.Services()
if err != nil {
t.Fatal(err)
}
if want, have := 0, len(services); want != have {
t.Errorf("want %d, have %d", want, have)
}
c2, err := NewClient(host, logger)
if err != nil {
t.Fatalf("Connect returned error: %v", err)
}
defer c2.Stop()
data, _, err := c2.(*client).Get(path)
if err != nil {
t.Fatal(err)
}
// test Client implementation of CreateParentNodes. It should have created
// our payload
if bytes.Compare(data, payload[1]) != 0 {
t.Errorf("want %s, have %s", payload[1], data)
}
}
func TestCreateBadParentNodesOnServer(t *testing.T) {
c, _ := NewClient(host, logger)
defer c.Stop()
_, err := NewSubscriber(c, "invalid/path", newFactory(""), logger)
if want, have := stdzk.ErrInvalidPath, err; want != have {
t.Errorf("want %v, have %v", want, have)
}
}
func TestCredentials1(t *testing.T) {
acl := stdzk.DigestACL(stdzk.PermAll, "user", "secret")
c, _ := NewClient(host, logger, ACL(acl), Credentials("user", "secret"))
defer c.Stop()
_, err := NewSubscriber(c, "/acl-issue-test", newFactory(""), logger)
if err != nil {
t.Fatal(err)
}
}
func TestCredentials2(t *testing.T) {
acl := stdzk.DigestACL(stdzk.PermAll, "user", "secret")
c, _ := NewClient(host, logger, ACL(acl))
defer c.Stop()
_, err := NewSubscriber(c, "/acl-issue-test", newFactory(""), logger)
if err != stdzk.ErrNoAuth {
t.Errorf("want %v, have %v", stdzk.ErrNoAuth, err)
}
}
func TestConnection(t *testing.T) {
c, _ := NewClient(host, logger)
c.Stop()
_, err := NewSubscriber(c, "/acl-issue-test", newFactory(""), logger)
if err != ErrClientClosed {
t.Errorf("want %v, have %v", ErrClientClosed, err)
}
}
func TestGetEntriesOnServer(t *testing.T) {
var instancePayload = "protocol://hostname:port/routing"
c1, err := NewClient(host, logger)
if err != nil {
t.Fatalf("Connect returned error: %v", err)
}
defer c1.Stop()
c2, err := NewClient(host, logger)
s, err := NewSubscriber(c2, path, newFactory(""), logger)
if err != nil {
t.Fatal(err)
}
defer c2.Stop()
c2impl, _ := c2.(*client)
_, err = c2impl.Create(
path+"/instance1",
[]byte(instancePayload),
stdzk.FlagEphemeral|stdzk.FlagSequence,
stdzk.WorldACL(stdzk.PermAll),
)
if err != nil {
t.Fatalf("Unable to create test ephemeral znode 1: %v", err)
}
_, err = c2impl.Create(
path+"/instance2",
[]byte(instancePayload+"2"),
stdzk.FlagEphemeral|stdzk.FlagSequence,
stdzk.WorldACL(stdzk.PermAll),
)
if err != nil {
t.Fatalf("Unable to create test ephemeral znode 2: %v", err)
}
time.Sleep(50 * time.Millisecond)
services, err := s.Services()
if err != nil {
t.Fatal(err)
}
if want, have := 2, len(services); want != have {
t.Errorf("want %d, have %d", want, have)
}
}
func TestGetEntriesPayloadOnServer(t *testing.T) {
c, err := NewClient(host, logger)
if err != nil {
t.Fatalf("Connect returned error: %v", err)
}
_, eventc, err := c.GetEntries(path)
if err != nil {
t.Fatal(err)
}
_, err = c.(*client).Create(
path+"/instance3",
[]byte("just some payload"),
stdzk.FlagEphemeral|stdzk.FlagSequence,
stdzk.WorldACL(stdzk.PermAll),
)
if err != nil {
t.Fatalf("Unable to create test ephemeral znode: %v", err)
}
select {
case event := <-eventc:
if want, have := stdzk.EventNodeChildrenChanged.String(), event.Type.String(); want != have {
t.Errorf("want %s, have %s", want, have)
}
case <-time.After(20 * time.Millisecond):
t.Errorf("expected incoming watch event, timeout occurred")
}
}

27
sd/zk/logwrapper.go Normal file
View File

@ -0,0 +1,27 @@
package zk
import (
"fmt"
"github.com/samuel/go-zookeeper/zk"
"github.com/go-kit/kit/log"
)
// wrapLogger wraps a Go kit logger so we can use it as the logging service for
// the ZooKeeper library, which expects a Printf method to be available.
type wrapLogger struct {
log.Logger
}
func (logger wrapLogger) Printf(format string, args ...interface{}) {
logger.Log("msg", fmt.Sprintf(format, args...))
}
// withLogger replaces the ZooKeeper library's default logging service with our
// own Go kit logger.
func withLogger(logger log.Logger) func(c *zk.Conn) {
return func(c *zk.Conn) {
c.SetLogger(wrapLogger{logger})
}
}

86
sd/zk/subscriber.go Normal file
View File

@ -0,0 +1,86 @@
package zk
import (
"github.com/samuel/go-zookeeper/zk"
"github.com/go-kit/kit/endpoint"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/sd"
"github.com/go-kit/kit/sd/cache"
)
// Subscriber yield endpoints stored in a certain ZooKeeper path. Any kind of
// change in that path is watched and will update the Subscriber endpoints.
type Subscriber struct {
client Client
path string
cache *cache.Cache
logger log.Logger
quitc chan struct{}
}
var _ sd.Subscriber = &Subscriber{}
// NewSubscriber returns a ZooKeeper subscriber. ZooKeeper will start watching
// the given path for changes and update the Subscriber endpoints.
func NewSubscriber(c Client, path string, factory sd.Factory, logger log.Logger) (*Subscriber, error) {
s := &Subscriber{
client: c,
path: path,
cache: cache.New(factory, logger),
logger: logger,
quitc: make(chan struct{}),
}
err := s.client.CreateParentNodes(s.path)
if err != nil {
return nil, err
}
instances, eventc, err := s.client.GetEntries(s.path)
if err != nil {
logger.Log("path", s.path, "msg", "failed to retrieve entries", "err", err)
return nil, err
}
logger.Log("path", s.path, "instances", len(instances))
s.cache.Update(instances)
go s.loop(eventc)
return s, nil
}
func (s *Subscriber) loop(eventc <-chan zk.Event) {
var (
instances []string
err error
)
for {
select {
case <-eventc:
// We received a path update notification. Call GetEntries to
// retrieve child node data, and set a new watch, as ZK watches are
// one-time triggers.
instances, eventc, err = s.client.GetEntries(s.path)
if err != nil {
s.logger.Log("path", s.path, "msg", "failed to retrieve entries", "err", err)
continue
}
s.logger.Log("path", s.path, "instances", len(instances))
s.cache.Update(instances)
case <-s.quitc:
return
}
}
}
// Endpoints implements the Subscriber interface.
func (s *Subscriber) Endpoints() ([]endpoint.Endpoint, error) {
return s.cache.Endpoints(), nil
}
// Stop terminates the Subscriber.
func (s *Subscriber) Stop() {
close(s.quitc)
}

117
sd/zk/subscriber_test.go Normal file
View File

@ -0,0 +1,117 @@
package zk
import (
"testing"
"time"
)
func TestSubscriber(t *testing.T) {
client := newFakeClient()
s, err := NewSubscriber(client, path, newFactory(""), logger)
if err != nil {
t.Fatalf("failed to create new Subscriber: %v", err)
}
defer s.Stop()
if _, err := s.Endpoints(); err != nil {
t.Fatal(err)
}
}
func TestBadFactory(t *testing.T) {
client := newFakeClient()
s, err := NewSubscriber(client, path, newFactory("kaboom"), logger)
if err != nil {
t.Fatalf("failed to create new Subscriber: %v", err)
}
defer s.Stop()
// instance1 came online
client.AddService(path+"/instance1", "kaboom")
// instance2 came online
client.AddService(path+"/instance2", "zookeeper_node_data")
if err = asyncTest(100*time.Millisecond, 1, s); err != nil {
t.Error(err)
}
}
func TestServiceUpdate(t *testing.T) {
client := newFakeClient()
s, err := NewSubscriber(client, path, newFactory(""), logger)
if err != nil {
t.Fatalf("failed to create new Subscriber: %v", err)
}
defer s.Stop()
endpoints, err := s.Endpoints()
if err != nil {
t.Fatal(err)
}
if want, have := 0, len(endpoints); want != have {
t.Errorf("want %d, have %d", want, have)
}
// instance1 came online
client.AddService(path+"/instance1", "zookeeper_node_data1")
// instance2 came online
client.AddService(path+"/instance2", "zookeeper_node_data2")
// we should have 2 instances
if err = asyncTest(100*time.Millisecond, 2, s); err != nil {
t.Error(err)
}
// TODO(pb): this bit is flaky
//
//// watch triggers an error...
//client.SendErrorOnWatch()
//
//// test if error was consumed
//if err = client.ErrorIsConsumedWithin(100 * time.Millisecond); err != nil {
// t.Error(err)
//}
// instance3 came online
client.AddService(path+"/instance3", "zookeeper_node_data3")
// we should have 3 instances
if err = asyncTest(100*time.Millisecond, 3, s); err != nil {
t.Error(err)
}
// instance1 goes offline
client.RemoveService(path + "/instance1")
// instance2 goes offline
client.RemoveService(path + "/instance2")
// we should have 1 instance
if err = asyncTest(100*time.Millisecond, 1, s); err != nil {
t.Error(err)
}
}
func TestBadSubscriberCreate(t *testing.T) {
client := newFakeClient()
client.SendErrorOnWatch()
s, err := NewSubscriber(client, path, newFactory(""), logger)
if err == nil {
t.Error("expected error on new Subscriber")
}
if s != nil {
t.Error("expected Subscriber not to be created")
}
s, err = NewSubscriber(client, "BadPath", newFactory(""), logger)
if err == nil {
t.Error("expected error on new Subscriber")
}
if s != nil {
t.Error("expected Subscriber not to be created")
}
}

126
sd/zk/util_test.go Normal file
View File

@ -0,0 +1,126 @@
package zk
import (
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/samuel/go-zookeeper/zk"
"golang.org/x/net/context"
"github.com/go-kit/kit/endpoint"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/sd"
)
var (
path = "/gokit.test/service.name"
e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
logger = log.NewNopLogger()
)
type fakeClient struct {
mtx sync.Mutex
ch chan zk.Event
responses map[string]string
result bool
}
func newFakeClient() *fakeClient {
return &fakeClient{
ch: make(chan zk.Event, 1),
responses: make(map[string]string),
result: true,
}
}
func (c *fakeClient) CreateParentNodes(path string) error {
if path == "BadPath" {
return errors.New("dummy error")
}
return nil
}
func (c *fakeClient) GetEntries(path string) ([]string, <-chan zk.Event, error) {
c.mtx.Lock()
defer c.mtx.Unlock()
if c.result == false {
c.result = true
return []string{}, c.ch, errors.New("dummy error")
}
responses := []string{}
for _, data := range c.responses {
responses = append(responses, data)
}
return responses, c.ch, nil
}
func (c *fakeClient) AddService(node, data string) {
c.mtx.Lock()
defer c.mtx.Unlock()
c.responses[node] = data
c.ch <- zk.Event{}
}
func (c *fakeClient) RemoveService(node string) {
c.mtx.Lock()
defer c.mtx.Unlock()
delete(c.responses, node)
c.ch <- zk.Event{}
}
func (c *fakeClient) SendErrorOnWatch() {
c.mtx.Lock()
defer c.mtx.Unlock()
c.result = false
c.ch <- zk.Event{}
}
func (c *fakeClient) ErrorIsConsumedWithin(timeout time.Duration) error {
t := time.After(timeout)
for {
select {
case <-t:
return fmt.Errorf("expected error not consumed after timeout %s", timeout)
default:
c.mtx.Lock()
if c.result == false {
c.mtx.Unlock()
return nil
}
c.mtx.Unlock()
}
}
}
func (c *fakeClient) Stop() {}
func newFactory(fakeError string) sd.Factory {
return func(instance string) (endpoint.Endpoint, io.Closer, error) {
if fakeError == instance {
return nil, nil, errors.New(fakeError)
}
return endpoint.Nop, nil, nil
}
}
func asyncTest(timeout time.Duration, want int, s *Subscriber) (err error) {
var endpoints []endpoint.Endpoint
have := -1 // want can never be <0
t := time.After(timeout)
for {
select {
case <-t:
return fmt.Errorf("want %d, have %d (timeout %s)", want, have, timeout.String())
default:
endpoints, err = s.Endpoints()
have = len(endpoints)
if err != nil || want == have {
return
}
time.Sleep(timeout / 10)
}
}
}