1
0
mirror of https://github.com/go-micro/go-micro.git synced 2025-06-12 22:07:47 +02:00

feat: add test framework & refactor RPC server (#2579)

Co-authored-by: Rene Jochum <rene@jochum.dev>
This commit is contained in:
David Brouwer
2022-10-20 13:00:50 +02:00
committed by GitHub
parent c25dee7c8a
commit a3980c2308
54 changed files with 3703 additions and 2497 deletions

View File

@ -39,7 +39,7 @@ jobs:
go get -v -t -d ./... go get -v -t -d ./...
- name: Run tests - name: Run tests
id: tests id: tests
run: richgo test -v -race -cover ./... run: richgo test -v -race -cover -bench=. ./...
env: env:
IN_TRAVIS_CI: yes IN_TRAVIS_CI: yes
RICHGO_FORCE_COLOR: 1 RICHGO_FORCE_COLOR: 1
@ -60,6 +60,6 @@ jobs:
go get -v -t -d ./... go get -v -t -d ./...
- name: Run tests - name: Run tests
id: tests id: tests
run: go test -v -race -cover -json ./... | tparse -notests -format=markdown >> $GITHUB_STEP_SUMMARY run: go test -v -race -cover -json -bench=. ./... | tparse -notests -format=markdown >> $GITHUB_STEP_SUMMARY
env: env:
IN_TRAVIS_CI: yes IN_TRAVIS_CI: yes

4
.gitignore vendored
View File

@ -35,3 +35,7 @@ _cgo_export.*
*~ *~
*.swp *.swp
*.swo *.swo
# go work files
go.work
go.work.sum

View File

@ -57,6 +57,11 @@ output:
# all available settings of specific linters # all available settings of specific linters
linters-settings: linters-settings:
wsl:
allow-cuddle-with-calls: ["Lock", "RLock", "defer"]
funlen:
lines: 80
statements: 60
varnamelen: varnamelen:
# The longest distance, in source lines, that is being considered a "small scope". # The longest distance, in source lines, that is being considered a "small scope".
# Variables used in at most this many lines will be ignored. # Variables used in at most this many lines will be ignored.
@ -184,6 +189,7 @@ linters:
- makezero - makezero
- gofumpt - gofumpt
- nlreturn - nlreturn
- thelper
# Can be considered to be enabled # Can be considered to be enabled
- gochecknoinits - gochecknoinits
@ -197,6 +203,9 @@ linters:
- exhaustruct - exhaustruct
- containedctx - containedctx
- godox - godox
- forcetypeassert
- gci
- lll
issues: issues:
# List of regexps of issue texts to exclude, empty list by default. # List of regexps of issue texts to exclude, empty list by default.

View File

@ -14,7 +14,7 @@ skipStyle:
foreground: lightBlack foreground: lightBlack
passPackageStyle: passPackageStyle:
foreground: green foreground: green
hide: true hide: false
failPackageStyle: failPackageStyle:
bold: true bold: true
foreground: "#821515" foreground: "#821515"

View File

@ -20,6 +20,7 @@ import (
merr "go-micro.dev/v4/errors" merr "go-micro.dev/v4/errors"
"go-micro.dev/v4/registry" "go-micro.dev/v4/registry"
"go-micro.dev/v4/registry/cache" "go-micro.dev/v4/registry/cache"
"go-micro.dev/v4/transport/headers"
maddr "go-micro.dev/v4/util/addr" maddr "go-micro.dev/v4/util/addr"
mnet "go-micro.dev/v4/util/net" mnet "go-micro.dev/v4/util/net"
mls "go-micro.dev/v4/util/tls" mls "go-micro.dev/v4/util/tls"
@ -313,7 +314,7 @@ func (h *httpBroker) ServeHTTP(w http.ResponseWriter, req *http.Request) {
return return
} }
topic := m.Header["Micro-Topic"] topic := m.Header[headers.Message]
// delete(m.Header, ":topic") // delete(m.Header, ":topic")
if len(topic) == 0 { if len(topic) == 0 {
@ -517,7 +518,7 @@ func (h *httpBroker) Publish(topic string, msg *Message, opts ...PublishOption)
m.Header[k] = v m.Header[k] = v
} }
m.Header["Micro-Topic"] = topic m.Header[headers.Message] = topic
// encode the message // encode the message
b, err := h.opts.Codec.Marshal(m) b, err := h.opts.Codec.Marshal(m)

View File

@ -8,7 +8,9 @@ import (
"time" "time"
cache "github.com/patrickmn/go-cache" cache "github.com/patrickmn/go-cache"
"go-micro.dev/v4/metadata" "go-micro.dev/v4/metadata"
"go-micro.dev/v4/transport/headers"
) )
// NewCache returns an initialized cache. // NewCache returns an initialized cache.
@ -38,6 +40,7 @@ func (c *Cache) List() map[string]string {
items := c.cache.Items() items := c.cache.Items()
rsp := make(map[string]string, len(items)) rsp := make(map[string]string, len(items))
for k, v := range items { for k, v := range items {
bytes, _ := json.Marshal(v.Object) bytes, _ := json.Marshal(v.Object)
rsp[k] = string(bytes) rsp[k] = string(bytes)
@ -48,7 +51,7 @@ func (c *Cache) List() map[string]string {
// key returns a hash for the context and request. // key returns a hash for the context and request.
func key(ctx context.Context, req *Request) string { func key(ctx context.Context, req *Request) string {
ns, _ := metadata.Get(ctx, "Micro-Namespace") ns, _ := metadata.Get(ctx, headers.Namespace)
bytes, _ := json.Marshal(map[string]interface{}{ bytes, _ := json.Marshal(map[string]interface{}{
"namespace": ns, "namespace": ns,
@ -62,5 +65,6 @@ func key(ctx context.Context, req *Request) string {
h := fnv.New64() h := fnv.New64()
h.Write(bytes) h.Write(bytes)
return fmt.Sprintf("%x", h.Sum(nil)) return fmt.Sprintf("%x", h.Sum(nil))
} }

View File

@ -6,6 +6,7 @@ import (
"time" "time"
"go-micro.dev/v4/metadata" "go-micro.dev/v4/metadata"
"go-micro.dev/v4/transport/headers"
) )
func TestCache(t *testing.T) { func TestCache(t *testing.T) {
@ -65,7 +66,7 @@ func TestCacheKey(t *testing.T) {
}) })
t.Run("DifferentMetadata", func(t *testing.T) { t.Run("DifferentMetadata", func(t *testing.T) {
mdCtx := metadata.Set(context.TODO(), "Micro-Namespace", "bar") mdCtx := metadata.Set(context.TODO(), headers.Namespace, "bar")
key1 := key(mdCtx, &req1) key1 := key(mdCtx, &req1)
key2 := key(ctx, &req1) key2 := key(ctx, &req1)

View File

@ -3,11 +3,17 @@ package client
import ( import (
"context" "context"
"time"
"go-micro.dev/v4/codec" "go-micro.dev/v4/codec"
) )
var (
// NewClient returns a new client.
NewClient func(...Option) Client = newRPCClient
// DefaultClient is a default client to use out of the box.
DefaultClient Client = newRPCClient()
)
// Client is the interface used to make requests to services. // Client is the interface used to make requests to services.
// It supports Request/Response via Transport and Publishing via the Broker. // It supports Request/Response via Transport and Publishing via the Broker.
// It also supports bidirectional streaming of requests. // It also supports bidirectional streaming of requests.
@ -102,26 +108,6 @@ type MessageOption func(*MessageOptions)
// RequestOption used by NewRequest. // RequestOption used by NewRequest.
type RequestOption func(*RequestOptions) type RequestOption func(*RequestOptions)
var (
// DefaultClient is a default client to use out of the box.
DefaultClient Client = newRpcClient()
// DefaultBackoff is the default backoff function for retries.
DefaultBackoff = exponentialBackoff
// DefaultRetry is the default check-for-retry function for retries.
DefaultRetry = RetryOnError
// DefaultRetries is the default number of times a request is tried.
DefaultRetries = 1
// DefaultRequestTimeout is the default request timeout.
DefaultRequestTimeout = time.Second * 5
// DefaultPoolSize sets the connection pool size.
DefaultPoolSize = 100
// DefaultPoolTTL sets the connection pool ttl.
DefaultPoolTTL = time.Minute
// NewClient returns a new client.
NewClient func(...Option) Client = newRpcClient
)
// Makes a synchronous call to a service using the default client. // Makes a synchronous call to a service using the default client.
func Call(ctx context.Context, request Request, response interface{}, opts ...CallOption) error { func Call(ctx context.Context, request Request, response interface{}, opts ...CallOption) error {
return DefaultClient.Call(ctx, request, response, opts...) return DefaultClient.Call(ctx, request, response, opts...)

View File

@ -12,6 +12,24 @@ import (
"go-micro.dev/v4/transport" "go-micro.dev/v4/transport"
) )
var (
// DefaultBackoff is the default backoff function for retries.
DefaultBackoff = exponentialBackoff
// DefaultRetry is the default check-for-retry function for retries.
DefaultRetry = RetryOnError
// DefaultRetries is the default number of times a request is tried.
DefaultRetries = 5
// DefaultRequestTimeout is the default request timeout.
DefaultRequestTimeout = time.Second * 30
// DefaultConnectionTimeout is the default connection timeout.
DefaultConnectionTimeout = time.Second * 5
// DefaultPoolSize sets the connection pool size.
DefaultPoolSize = 100
// DefaultPoolTTL sets the connection pool ttl.
DefaultPoolTTL = time.Minute
)
// Options are the Client options.
type Options struct { type Options struct {
// Used to select codec // Used to select codec
ContentType string ContentType string
@ -47,6 +65,7 @@ type Options struct {
Context context.Context Context context.Context
} }
// CallOptions are options used to make calls to a server.
type CallOptions struct { type CallOptions struct {
SelectOptions []selector.SelectOption SelectOptions []selector.SelectOption
@ -56,11 +75,14 @@ type CallOptions struct {
Backoff BackoffFunc Backoff BackoffFunc
// Check if retriable func // Check if retriable func
Retry RetryFunc Retry RetryFunc
// Transport Dial Timeout
DialTimeout time.Duration
// Number of Call attempts // Number of Call attempts
Retries int Retries int
// Request/Response timeout // Transport Dial Timeout. Used for initial dial to establish a connection.
DialTimeout time.Duration
// ConnectionTimeout of one request to the server.
// Set this lower than the RequestTimeout to enbale retries on connection timeout.
ConnectionTimeout time.Duration
// Request/Response timeout of entire srv.Call, for single request timeout set ConnectionTimeout.
RequestTimeout time.Duration RequestTimeout time.Duration
// Stream timeout for the stream // Stream timeout for the stream
StreamTimeout time.Duration StreamTimeout time.Duration
@ -68,6 +90,8 @@ type CallOptions struct {
ServiceToken bool ServiceToken bool
// Duration to cache the response for // Duration to cache the response for
CacheExpiry time.Duration CacheExpiry time.Duration
// ConnClose sets the Connection: close header.
ConnClose bool
// Middleware for low level call func // Middleware for low level call func
CallWrappers []CallWrapper CallWrappers []CallWrapper
@ -98,6 +122,7 @@ type RequestOptions struct {
Context context.Context Context context.Context
} }
// NewOptions creates new Client options.
func NewOptions(options ...Option) Options { func NewOptions(options ...Option) Options {
opts := Options{ opts := Options{
Cache: NewCache(), Cache: NewCache(),
@ -105,11 +130,12 @@ func NewOptions(options ...Option) Options {
ContentType: DefaultContentType, ContentType: DefaultContentType,
Codecs: make(map[string]codec.NewCodec), Codecs: make(map[string]codec.NewCodec),
CallOptions: CallOptions{ CallOptions: CallOptions{
Backoff: DefaultBackoff, Backoff: DefaultBackoff,
Retry: DefaultRetry, Retry: DefaultRetry,
Retries: DefaultRetries, Retries: DefaultRetries,
RequestTimeout: DefaultRequestTimeout, RequestTimeout: DefaultRequestTimeout,
DialTimeout: transport.DefaultDialTimeout, ConnectionTimeout: DefaultConnectionTimeout,
DialTimeout: transport.DefaultDialTimeout,
}, },
PoolSize: DefaultPoolSize, PoolSize: DefaultPoolSize,
PoolTTL: DefaultPoolTTL, PoolTTL: DefaultPoolTTL,
@ -141,7 +167,7 @@ func Codec(contentType string, c codec.NewCodec) Option {
} }
} }
// Default content type of the client. // ContentType sets the default content type of the client.
func ContentType(ct string) Option { func ContentType(ct string) Option {
return func(o *Options) { return func(o *Options) {
o.ContentType = ct o.ContentType = ct
@ -207,8 +233,7 @@ func Backoff(fn BackoffFunc) Option {
} }
} }
// Number of retries when making the request. // Retries set the number of retries when making the request.
// Should this be a Call Option?
func Retries(i int) Option { func Retries(i int) Option {
return func(o *Options) { return func(o *Options) {
o.CallOptions.Retries = i o.CallOptions.Retries = i
@ -222,8 +247,7 @@ func Retry(fn RetryFunc) Option {
} }
} }
// The request timeout. // RequestTimeout set the request timeout.
// Should this be a Call Option?
func RequestTimeout(d time.Duration) Option { func RequestTimeout(d time.Duration) Option {
return func(o *Options) { return func(o *Options) {
o.CallOptions.RequestTimeout = d o.CallOptions.RequestTimeout = d
@ -237,7 +261,7 @@ func StreamTimeout(d time.Duration) Option {
} }
} }
// Transport dial timeout. // DialTimeout sets the transport dial timeout.
func DialTimeout(d time.Duration) Option { func DialTimeout(d time.Duration) Option {
return func(o *Options) { return func(o *Options) {
o.CallOptions.DialTimeout = d o.CallOptions.DialTimeout = d
@ -296,8 +320,8 @@ func WithRetry(fn RetryFunc) CallOption {
} }
} }
// WithRetries is a CallOption which overrides that which // WithRetries sets the number of tries for a call.
// set in Options.CallOptions. // This CallOption overrides Options.CallOptions.
func WithRetries(i int) CallOption { func WithRetries(i int) CallOption {
return func(o *CallOptions) { return func(o *CallOptions) {
o.Retries = i o.Retries = i
@ -312,6 +336,13 @@ func WithRequestTimeout(d time.Duration) CallOption {
} }
} }
// WithConnClose sets the Connection header to close.
func WithConnClose() CallOption {
return func(o *CallOptions) {
o.ConnClose = true
}
}
// WithStreamTimeout sets the stream timeout. // WithStreamTimeout sets the stream timeout.
func WithStreamTimeout(d time.Duration) CallOption { func WithStreamTimeout(d time.Duration) CallOption {
return func(o *CallOptions) { return func(o *CallOptions) {

View File

@ -26,8 +26,9 @@ func RetryOnError(ctx context.Context, req Request, retryCount int, err error) (
} }
switch e.Code { switch e.Code {
// retry on timeout or internal server error // Retry on timeout, not on 500 internal server error, as that is a business
case 408, 500: // logic error that should be handled by the user.
case 408:
return true, nil return true, nil
default: default:
return false, nil return false, nil

View File

@ -3,31 +3,42 @@ package client
import ( import (
"context" "context"
"fmt" "fmt"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/pkg/errors"
"go-micro.dev/v4/broker" "go-micro.dev/v4/broker"
"go-micro.dev/v4/codec" "go-micro.dev/v4/codec"
raw "go-micro.dev/v4/codec/bytes" raw "go-micro.dev/v4/codec/bytes"
"go-micro.dev/v4/errors" merrors "go-micro.dev/v4/errors"
log "go-micro.dev/v4/logger"
"go-micro.dev/v4/metadata" "go-micro.dev/v4/metadata"
"go-micro.dev/v4/registry" "go-micro.dev/v4/registry"
"go-micro.dev/v4/selector" "go-micro.dev/v4/selector"
"go-micro.dev/v4/transport" "go-micro.dev/v4/transport"
"go-micro.dev/v4/transport/headers"
"go-micro.dev/v4/util/buf" "go-micro.dev/v4/util/buf"
"go-micro.dev/v4/util/net" "go-micro.dev/v4/util/net"
"go-micro.dev/v4/util/pool" "go-micro.dev/v4/util/pool"
) )
const (
packageID = "go.micro.client"
)
type rpcClient struct { type rpcClient struct {
seq uint64 seq uint64
once atomic.Value once atomic.Value
opts Options opts Options
pool pool.Pool pool pool.Pool
mu sync.RWMutex
} }
func newRpcClient(opt ...Option) Client { func newRPCClient(opt ...Option) Client {
opts := NewOptions(opt...) opts := NewOptions(opt...)
p := pool.NewPool( p := pool.NewPool(
@ -57,14 +68,17 @@ func (r *rpcClient) newCodec(contentType string) (codec.NewCodec, error) {
if c, ok := r.opts.Codecs[contentType]; ok { if c, ok := r.opts.Codecs[contentType]; ok {
return c, nil return c, nil
} }
if cf, ok := DefaultCodecs[contentType]; ok { if cf, ok := DefaultCodecs[contentType]; ok {
return cf, nil return cf, nil
} }
return nil, fmt.Errorf("unsupported Content-Type: %s", contentType) return nil, fmt.Errorf("unsupported Content-Type: %s", contentType)
} }
func (r *rpcClient) call(ctx context.Context, node *registry.Node, req Request, resp interface{}, opts CallOptions) error { func (r *rpcClient) call(ctx context.Context, node *registry.Node, req Request, resp interface{}, opts CallOptions) error {
address := node.Address address := node.Address
logger := r.Options().Logger
msg := &transport.Message{ msg := &transport.Message{
Header: make(map[string]string), Header: make(map[string]string),
@ -73,31 +87,43 @@ func (r *rpcClient) call(ctx context.Context, node *registry.Node, req Request,
md, ok := metadata.FromContext(ctx) md, ok := metadata.FromContext(ctx)
if ok { if ok {
for k, v := range md { for k, v := range md {
// don't copy Micro-Topic header, that used for pub/sub // Don't copy Micro-Topic header, that is used for pub/sub
// this fix case then client uses the same context that received in subscriber // this is fixes the case when the client uses the same context that
if k == "Micro-Topic" { // is received in the subscriber.
if k == headers.Message {
continue continue
} }
msg.Header[k] = v msg.Header[k] = v
} }
} }
// Set connection timeout for single requests to the server. Should be > 0
// as otherwise requests can't be made.
cTimeout := opts.ConnectionTimeout
if cTimeout == 0 {
logger.Log(log.DebugLevel, "connection timeout was set to 0, overridng to default connection timeout")
cTimeout = DefaultConnectionTimeout
}
// set timeout in nanoseconds // set timeout in nanoseconds
msg.Header["Timeout"] = fmt.Sprintf("%d", opts.RequestTimeout) msg.Header["Timeout"] = fmt.Sprintf("%d", cTimeout)
// set the content type for the request // set the content type for the request
msg.Header["Content-Type"] = req.ContentType() msg.Header["Content-Type"] = req.ContentType()
// set the accept header // set the accept header
msg.Header["Accept"] = req.ContentType() msg.Header["Accept"] = req.ContentType()
// setup old protocol // setup old protocol
cf := setupProtocol(msg, node) reqCodec := setupProtocol(msg, node)
// no codec specified // no codec specified
if cf == nil { if reqCodec == nil {
var err error var err error
cf, err = r.newCodec(req.ContentType()) reqCodec, err = r.newCodec(req.ContentType())
if err != nil { if err != nil {
return errors.InternalServerError("go.micro.client", err.Error()) return merrors.InternalServerError("go.micro.client", err.Error())
} }
} }
@ -109,19 +135,29 @@ func (r *rpcClient) call(ctx context.Context, node *registry.Node, req Request,
dOpts = append(dOpts, transport.WithTimeout(opts.DialTimeout)) dOpts = append(dOpts, transport.WithTimeout(opts.DialTimeout))
} }
if opts.ConnClose {
dOpts = append(dOpts, transport.WithConnClose())
}
c, err := r.pool.Get(address, dOpts...) c, err := r.pool.Get(address, dOpts...)
if err != nil { if err != nil {
return errors.InternalServerError("go.micro.client", "connection error: %v", err) return merrors.InternalServerError("go.micro.client", "connection error: %v", err)
} }
seq := atomic.AddUint64(&r.seq, 1) - 1 seq := atomic.AddUint64(&r.seq, 1) - 1
codec := newRpcCodec(msg, c, cf, "") codec := newRPCCodec(msg, c, reqCodec, "")
rsp := &rpcResponse{ rsp := &rpcResponse{
socket: c, socket: c,
codec: codec, codec: codec,
} }
releaseFunc := func(err error) {
if err = r.pool.Release(c, err); err != nil {
logger.Log(log.ErrorLevel, "failed to release pool", err)
}
}
stream := &rpcStream{ stream := &rpcStream{
id: fmt.Sprintf("%v", seq), id: fmt.Sprintf("%v", seq),
context: ctx, context: ctx,
@ -129,11 +165,17 @@ func (r *rpcClient) call(ctx context.Context, node *registry.Node, req Request,
response: rsp, response: rsp,
codec: codec, codec: codec,
closed: make(chan bool), closed: make(chan bool),
release: func(err error) { r.pool.Release(c, err) }, close: opts.ConnClose,
release: releaseFunc,
sendEOS: false, sendEOS: false,
} }
// close the stream on exiting this function // close the stream on exiting this function
defer stream.Close() defer func() {
if err := stream.Close(); err != nil {
logger.Log(log.ErrorLevel, "failed to close stream", err)
}
}()
// wait for error response // wait for error response
ch := make(chan error, 1) ch := make(chan error, 1)
@ -141,7 +183,7 @@ func (r *rpcClient) call(ctx context.Context, node *registry.Node, req Request,
go func() { go func() {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
ch <- errors.InternalServerError("go.micro.client", "panic recovered: %v", r) ch <- merrors.InternalServerError("go.micro.client", "panic recovered: %v", r)
} }
}() }()
@ -166,8 +208,8 @@ func (r *rpcClient) call(ctx context.Context, node *registry.Node, req Request,
select { select {
case err := <-ch: case err := <-ch:
return err return err
case <-ctx.Done(): case <-time.After(cTimeout):
grr = errors.Timeout("go.micro.client", fmt.Sprintf("%v", ctx.Err())) grr = merrors.Timeout("go.micro.client", fmt.Sprintf("%v", ctx.Err()))
} }
// set the stream error // set the stream error
@ -184,6 +226,7 @@ func (r *rpcClient) call(ctx context.Context, node *registry.Node, req Request,
func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request, opts CallOptions) (Stream, error) { func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request, opts CallOptions) (Stream, error) {
address := node.Address address := node.Address
logger := r.Options().Logger
msg := &transport.Message{ msg := &transport.Message{
Header: make(map[string]string), Header: make(map[string]string),
@ -206,14 +249,15 @@ func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request
msg.Header["Accept"] = req.ContentType() msg.Header["Accept"] = req.ContentType()
// set old codecs // set old codecs
cf := setupProtocol(msg, node) nCodec := setupProtocol(msg, node)
// no codec specified // no codec specified
if cf == nil { if nCodec == nil {
var err error var err error
cf, err = r.newCodec(req.ContentType())
nCodec, err = r.newCodec(req.ContentType())
if err != nil { if err != nil {
return nil, errors.InternalServerError("go.micro.client", err.Error()) return nil, merrors.InternalServerError("go.micro.client", err.Error())
} }
} }
@ -227,7 +271,7 @@ func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request
c, err := r.opts.Transport.Dial(address, dOpts...) c, err := r.opts.Transport.Dial(address, dOpts...)
if err != nil { if err != nil {
return nil, errors.InternalServerError("go.micro.client", "connection error: %v", err) return nil, merrors.InternalServerError("go.micro.client", "connection error: %v", err)
} }
// increment the sequence number // increment the sequence number
@ -235,7 +279,7 @@ func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request
id := fmt.Sprintf("%v", seq) id := fmt.Sprintf("%v", seq)
// create codec with stream id // create codec with stream id
codec := newRpcCodec(msg, c, cf, id) codec := newRPCCodec(msg, c, nCodec, id)
rsp := &rpcResponse{ rsp := &rpcResponse{
socket: c, socket: c,
@ -247,6 +291,12 @@ func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request
r.codec = codec r.codec = codec
} }
releaseFunc := func(_ error) {
if err = c.Close(); err != nil {
logger.Log(log.ErrorLevel, err)
}
}
stream := &rpcStream{ stream := &rpcStream{
id: id, id: id,
context: ctx, context: ctx,
@ -257,8 +307,7 @@ func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request
closed: make(chan bool), closed: make(chan bool),
// signal the end of stream, // signal the end of stream,
sendEOS: true, sendEOS: true,
// release func release: releaseFunc,
release: func(err error) { c.Close() },
} }
// wait for error response // wait for error response
@ -275,7 +324,7 @@ func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request
case err := <-ch: case err := <-ch:
grr = err grr = err
case <-ctx.Done(): case <-ctx.Done():
grr = errors.Timeout("go.micro.client", fmt.Sprintf("%v", ctx.Err())) grr = merrors.Timeout("go.micro.client", fmt.Sprintf("%v", ctx.Err()))
} }
if grr != nil { if grr != nil {
@ -285,7 +334,10 @@ func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request
stream.Unlock() stream.Unlock()
// close the stream // close the stream
stream.Close() if err := stream.Close(); err != nil {
logger.Logf(log.ErrorLevel, "failed to close stream: %v", err)
}
return nil, grr return nil, grr
} }
@ -293,6 +345,9 @@ func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request
} }
func (r *rpcClient) Init(opts ...Option) error { func (r *rpcClient) Init(opts ...Option) error {
r.mu.Lock()
defer r.mu.Unlock()
size := r.opts.PoolSize size := r.opts.PoolSize
ttl := r.opts.PoolTTL ttl := r.opts.PoolTTL
tr := r.opts.Transport tr := r.opts.Transport
@ -304,7 +359,10 @@ func (r *rpcClient) Init(opts ...Option) error {
// update pool configuration if the options changed // update pool configuration if the options changed
if size != r.opts.PoolSize || ttl != r.opts.PoolTTL || tr != r.opts.Transport { if size != r.opts.PoolSize || ttl != r.opts.PoolTTL || tr != r.opts.Transport {
// close existing pool // close existing pool
r.pool.Close() if err := r.pool.Close(); err != nil {
return errors.Wrap(err, "failed to close pool")
}
// create new pool // create new pool
r.pool = pool.NewPool( r.pool = pool.NewPool(
pool.Size(r.opts.PoolSize), pool.Size(r.opts.PoolSize),
@ -316,7 +374,11 @@ func (r *rpcClient) Init(opts ...Option) error {
return nil return nil
} }
// Options retrives the options.
func (r *rpcClient) Options() Options { func (r *rpcClient) Options() Options {
r.mu.RLock()
defer r.mu.RUnlock()
return r.opts return r.opts
} }
@ -348,16 +410,22 @@ func (r *rpcClient) next(request Request, opts CallOptions) (selector.Next, erro
// get next nodes from the selector // get next nodes from the selector
next, err := r.opts.Selector.Select(service, opts.SelectOptions...) next, err := r.opts.Selector.Select(service, opts.SelectOptions...)
if err != nil { if err != nil {
if err == selector.ErrNotFound { if errors.Is(err, selector.ErrNotFound) {
return nil, errors.InternalServerError("go.micro.client", "service %s: %s", service, err.Error()) return nil, merrors.InternalServerError("go.micro.client", "service %s: %s", service, err.Error())
} }
return nil, errors.InternalServerError("go.micro.client", "error selecting %s node: %s", service, err.Error())
return nil, merrors.InternalServerError("go.micro.client", "error selecting %s node: %s", service, err.Error())
} }
return next, nil return next, nil
} }
func (r *rpcClient) Call(ctx context.Context, request Request, response interface{}, opts ...CallOption) error { func (r *rpcClient) Call(ctx context.Context, request Request, response interface{}, opts ...CallOption) error {
// TODO: further validate these mutex locks. full lock would prevent
// parallel calls. Maybe we can set individual locks for secctions.
r.mu.RLock()
defer r.mu.RUnlock()
// make a copy of call opts // make a copy of call opts
callOpts := r.opts.CallOptions callOpts := r.opts.CallOptions
for _, opt := range opts { for _, opt := range opts {
@ -375,6 +443,7 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac
// no deadline so we create a new one // no deadline so we create a new one
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, callOpts.RequestTimeout) ctx, cancel = context.WithTimeout(ctx, callOpts.RequestTimeout)
defer cancel() defer cancel()
} else { } else {
// got a deadline so no need to setup context // got a deadline so no need to setup context
@ -386,7 +455,7 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac
// should we noop right here? // should we noop right here?
select { select {
case <-ctx.Done(): case <-ctx.Done():
return errors.Timeout("go.micro.client", fmt.Sprintf("%v", ctx.Err())) return merrors.Timeout("go.micro.client", fmt.Sprintf("%v", ctx.Err()))
default: default:
} }
@ -403,7 +472,7 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac
// call backoff first. Someone may want an initial start delay // call backoff first. Someone may want an initial start delay
t, err := callOpts.Backoff(ctx, request, i) t, err := callOpts.Backoff(ctx, request, i)
if err != nil { if err != nil {
return errors.InternalServerError("go.micro.client", "backoff error: %v", err.Error()) return merrors.InternalServerError("go.micro.client", "backoff error: %v", err.Error())
} }
// only sleep if greater than 0 // only sleep if greater than 0
@ -414,16 +483,19 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac
// select next node // select next node
node, err := next() node, err := next()
service := request.Service() service := request.Service()
if err != nil { if err != nil {
if err == selector.ErrNotFound { if errors.Is(err, selector.ErrNotFound) {
return errors.InternalServerError("go.micro.client", "service %s: %s", service, err.Error()) return merrors.InternalServerError("go.micro.client", "service %s: %s", service, err.Error())
} }
return errors.InternalServerError("go.micro.client", "error getting next %s node: %s", service, err.Error())
return merrors.InternalServerError("go.micro.client", "error getting next %s node: %s", service, err.Error())
} }
// make the call // make the call
err = rcall(ctx, node, request, response, callOpts) err = rcall(ctx, node, request, response, callOpts)
r.opts.Selector.Mark(service, node, err) r.opts.Selector.Mark(service, node, err)
return err return err
} }
@ -431,11 +503,13 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac
retries := callOpts.Retries retries := callOpts.Retries
// disable retries when using a proxy // disable retries when using a proxy
if _, _, ok := net.Proxy(request.Service(), callOpts.Address); ok { // Note: I don't see why we should disable retries for proxies, so commenting out.
retries = 0 // if _, _, ok := net.Proxy(request.Service(), callOpts.Address); ok {
} // retries = 0
// }
ch := make(chan error, retries+1) ch := make(chan error, retries+1)
var gerr error var gerr error
for i := 0; i <= retries; i++ { for i := 0; i <= retries; i++ {
@ -445,7 +519,7 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac
select { select {
case <-ctx.Done(): case <-ctx.Done():
return errors.Timeout("go.micro.client", fmt.Sprintf("call timeout: %v", ctx.Err())) return merrors.Timeout("go.micro.client", fmt.Sprintf("call timeout: %v", ctx.Err()))
case err := <-ch: case err := <-ch:
// if the call succeeded lets bail early // if the call succeeded lets bail early
if err == nil { if err == nil {
@ -461,6 +535,8 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac
return err return err
} }
r.opts.Logger.Logf(log.DebugLevel, "Retrying request. Previous attempt failed with: %v", err)
gerr = err gerr = err
} }
} }
@ -469,6 +545,9 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac
} }
func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOption) (Stream, error) { func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOption) (Stream, error) {
r.mu.RLock()
defer r.mu.RUnlock()
// make a copy of call opts // make a copy of call opts
callOpts := r.opts.CallOptions callOpts := r.opts.CallOptions
for _, opt := range opts { for _, opt := range opts {
@ -480,10 +559,9 @@ func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOpt
return nil, err return nil, err
} }
// should we noop right here?
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, errors.Timeout("go.micro.client", fmt.Sprintf("%v", ctx.Err())) return nil, merrors.Timeout("go.micro.client", fmt.Sprintf("%v", ctx.Err()))
default: default:
} }
@ -491,7 +569,7 @@ func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOpt
// call backoff first. Someone may want an initial start delay // call backoff first. Someone may want an initial start delay
t, err := callOpts.Backoff(ctx, request, i) t, err := callOpts.Backoff(ctx, request, i)
if err != nil { if err != nil {
return nil, errors.InternalServerError("go.micro.client", "backoff error: %v", err.Error()) return nil, merrors.InternalServerError("go.micro.client", "backoff error: %v", err.Error())
} }
// only sleep if greater than 0 // only sleep if greater than 0
@ -501,15 +579,18 @@ func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOpt
node, err := next() node, err := next()
service := request.Service() service := request.Service()
if err != nil { if err != nil {
if err == selector.ErrNotFound { if errors.Is(err, selector.ErrNotFound) {
return nil, errors.InternalServerError("go.micro.client", "service %s: %s", service, err.Error()) return nil, merrors.InternalServerError("go.micro.client", "service %s: %s", service, err.Error())
} }
return nil, errors.InternalServerError("go.micro.client", "error getting next %s node: %s", service, err.Error())
return nil, merrors.InternalServerError("go.micro.client", "error getting next %s node: %s", service, err.Error())
} }
stream, err := r.stream(ctx, node, request, callOpts) stream, err := r.stream(ctx, node, request, callOpts)
r.opts.Selector.Mark(service, node, err) r.opts.Selector.Mark(service, node, err)
return stream, err return stream, err
} }
@ -527,6 +608,7 @@ func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOpt
} }
ch := make(chan response, retries+1) ch := make(chan response, retries+1)
var grr error var grr error
for i := 0; i <= retries; i++ { for i := 0; i <= retries; i++ {
@ -537,7 +619,7 @@ func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOpt
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, errors.Timeout("go.micro.client", fmt.Sprintf("call timeout: %v", ctx.Err())) return nil, merrors.Timeout("go.micro.client", fmt.Sprintf("call timeout: %v", ctx.Err()))
case rsp := <-ch: case rsp := <-ch:
// if the call succeeded lets bail early // if the call succeeded lets bail early
if rsp.err == nil { if rsp.err == nil {
@ -568,15 +650,15 @@ func (r *rpcClient) Publish(ctx context.Context, msg Message, opts ...PublishOpt
o(&options) o(&options)
} }
md, ok := metadata.FromContext(ctx) metadata, ok := metadata.FromContext(ctx)
if !ok { if !ok {
md = make(map[string]string) metadata = make(map[string]string)
} }
id := uuid.New().String() id := uuid.New().String()
md["Content-Type"] = msg.ContentType() metadata["Content-Type"] = msg.ContentType()
md["Micro-Topic"] = msg.Topic() metadata[headers.Message] = msg.Topic()
md["Micro-Id"] = id metadata[headers.ID] = id
// set the topic // set the topic
topic := msg.Topic() topic := msg.Topic()
@ -589,7 +671,7 @@ func (r *rpcClient) Publish(ctx context.Context, msg Message, opts ...PublishOpt
// encode message body // encode message body
cf, err := r.newCodec(msg.ContentType()) cf, err := r.newCodec(msg.ContentType())
if err != nil { if err != nil {
return errors.InternalServerError("go.micro.client", err.Error()) return merrors.InternalServerError(packageID, err.Error())
} }
var body []byte var body []byte
@ -598,33 +680,38 @@ func (r *rpcClient) Publish(ctx context.Context, msg Message, opts ...PublishOpt
if d, ok := msg.Payload().(*raw.Frame); ok { if d, ok := msg.Payload().(*raw.Frame); ok {
body = d.Data body = d.Data
} else { } else {
// new buffer
b := buf.New(nil) b := buf.New(nil)
if err := cf(b).Write(&codec.Message{ if err = cf(b).Write(&codec.Message{
Target: topic, Target: topic,
Type: codec.Event, Type: codec.Event,
Header: map[string]string{ Header: map[string]string{
"Micro-Id": id, headers.ID: id,
"Micro-Topic": msg.Topic(), headers.Message: msg.Topic(),
}, },
}, msg.Payload()); err != nil { }, msg.Payload()); err != nil {
return errors.InternalServerError("go.micro.client", err.Error()) return merrors.InternalServerError(packageID, err.Error())
} }
// set the body // set the body
body = b.Bytes() body = b.Bytes()
} }
if !r.once.Load().(bool) { l, ok := r.once.Load().(bool)
if !ok {
return fmt.Errorf("failed to cast to bool")
}
if !l {
if err = r.opts.Broker.Connect(); err != nil { if err = r.opts.Broker.Connect(); err != nil {
return errors.InternalServerError("go.micro.client", err.Error()) return merrors.InternalServerError(packageID, err.Error())
} }
r.once.Store(true) r.once.Store(true)
} }
return r.opts.Broker.Publish(topic, &broker.Message{ return r.opts.Broker.Publish(topic, &broker.Message{
Header: md, Header: metadata,
Body: body, Body: body,
}, broker.PublishContext(options.Context)) }, broker.PublishContext(options.Context))
} }

View File

@ -10,18 +10,23 @@ import (
"go-micro.dev/v4/selector" "go-micro.dev/v4/selector"
) )
const (
serviceName = "test.service"
serviceEndpoint = "Test.Endpoint"
)
func newTestRegistry() registry.Registry { func newTestRegistry() registry.Registry {
return registry.NewMemoryRegistry(registry.Services(testData)) return registry.NewMemoryRegistry(registry.Services(testData))
} }
func TestCallAddress(t *testing.T) { func TestCallAddress(t *testing.T) {
var called bool var called bool
service := "test.service" service := serviceName
endpoint := "Test.Endpoint" endpoint := serviceEndpoint
address := "10.1.10.1:8080" address := "10.1.10.1:8080"
wrap := func(cf CallFunc) CallFunc { wrap := func(cf CallFunc) CallFunc {
return func(ctx context.Context, node *registry.Node, req Request, rsp interface{}, opts CallOptions) error { return func(_ context.Context, node *registry.Node, req Request, _ interface{}, _ CallOptions) error {
called = true called = true
if req.Service() != service { if req.Service() != service {
@ -46,7 +51,10 @@ func TestCallAddress(t *testing.T) {
Registry(r), Registry(r),
WrapCall(wrap), WrapCall(wrap),
) )
c.Options().Selector.Init(selector.Registry(r))
if err := c.Options().Selector.Init(selector.Registry(r)); err != nil {
t.Fatal("failed to initialize selector", err)
}
req := c.NewRequest(service, endpoint, nil) req := c.NewRequest(service, endpoint, nil)
@ -68,7 +76,7 @@ func TestCallRetry(t *testing.T) {
var called int var called int
wrap := func(cf CallFunc) CallFunc { wrap := func(cf CallFunc) CallFunc {
return func(ctx context.Context, node *registry.Node, req Request, rsp interface{}, opts CallOptions) error { return func(_ context.Context, _ *registry.Node, _ Request, _ interface{}, _ CallOptions) error {
called++ called++
if called == 1 { if called == 1 {
return errors.InternalServerError("test.error", "retry request") return errors.InternalServerError("test.error", "retry request")
@ -84,7 +92,10 @@ func TestCallRetry(t *testing.T) {
Registry(r), Registry(r),
WrapCall(wrap), WrapCall(wrap),
) )
c.Options().Selector.Init(selector.Registry(r))
if err := c.Options().Selector.Init(selector.Registry(r)); err != nil {
t.Fatal("failed to initialize selector", err)
}
req := c.NewRequest(service, endpoint, nil) req := c.NewRequest(service, endpoint, nil)
@ -107,7 +118,7 @@ func TestCallWrapper(t *testing.T) {
address := "10.1.10.1:8080" address := "10.1.10.1:8080"
wrap := func(cf CallFunc) CallFunc { wrap := func(cf CallFunc) CallFunc {
return func(ctx context.Context, node *registry.Node, req Request, rsp interface{}, opts CallOptions) error { return func(_ context.Context, node *registry.Node, req Request, _ interface{}, _ CallOptions) error {
called = true called = true
if req.Service() != service { if req.Service() != service {
@ -132,9 +143,12 @@ func TestCallWrapper(t *testing.T) {
Registry(r), Registry(r),
WrapCall(wrap), WrapCall(wrap),
) )
c.Options().Selector.Init(selector.Registry(r))
r.Register(&registry.Service{ if err := c.Options().Selector.Init(selector.Registry(r)); err != nil {
t.Fatal("failed to initialize selector", err)
}
err := r.Register(&registry.Service{
Name: service, Name: service,
Version: "latest", Version: "latest",
Nodes: []*registry.Node{ Nodes: []*registry.Node{
@ -147,6 +161,9 @@ func TestCallWrapper(t *testing.T) {
}, },
}, },
}) })
if err != nil {
t.Fatal("failed to register service", err)
}
req := c.NewRequest(service, endpoint, nil) req := c.NewRequest(service, endpoint, nil)
if err := c.Call(context.Background(), req, nil); err != nil { if err := c.Call(context.Background(), req, nil); err != nil {

View File

@ -14,6 +14,7 @@ import (
"go-micro.dev/v4/errors" "go-micro.dev/v4/errors"
"go-micro.dev/v4/registry" "go-micro.dev/v4/registry"
"go-micro.dev/v4/transport" "go-micro.dev/v4/transport"
"go-micro.dev/v4/transport/headers"
) )
const ( const (
@ -50,8 +51,10 @@ type readWriteCloser struct {
} }
var ( var (
// DefaultContentType header.
DefaultContentType = "application/json" DefaultContentType = "application/json"
// DefaultCodecs map.
DefaultCodecs = map[string]codec.NewCodec{ DefaultCodecs = map[string]codec.NewCodec{
"application/grpc": grpc.NewCodec, "application/grpc": grpc.NewCodec,
"application/grpc+json": grpc.NewCodec, "application/grpc+json": grpc.NewCodec,
@ -84,6 +87,7 @@ func (rwc *readWriteCloser) Write(p []byte) (n int, err error) {
func (rwc *readWriteCloser) Close() error { func (rwc *readWriteCloser) Close() error {
rwc.rbuf.Reset() rwc.rbuf.Reset()
rwc.wbuf.Reset() rwc.wbuf.Reset()
return nil return nil
} }
@ -92,20 +96,21 @@ func getHeaders(m *codec.Message) {
if len(v) > 0 { if len(v) > 0 {
return v return v
} }
return m.Header[hdr] return m.Header[hdr]
} }
// check error in header // check error in header
m.Error = set(m.Error, "Micro-Error") m.Error = set(m.Error, headers.Error)
// check endpoint in header // check endpoint in header
m.Endpoint = set(m.Endpoint, "Micro-Endpoint") m.Endpoint = set(m.Endpoint, headers.Endpoint)
// check method in header // check method in header
m.Method = set(m.Method, "Micro-Method") m.Method = set(m.Method, headers.Method)
// set the request id // set the request id
m.Id = set(m.Id, "Micro-Id") m.Id = set(m.Id, headers.ID)
} }
func setHeaders(m *codec.Message, stream string) { func setHeaders(m *codec.Message, stream string) {
@ -113,17 +118,18 @@ func setHeaders(m *codec.Message, stream string) {
if len(v) == 0 { if len(v) == 0 {
return return
} }
m.Header[hdr] = v m.Header[hdr] = v
} }
set("Micro-Id", m.Id) set(headers.ID, m.Id)
set("Micro-Service", m.Target) set(headers.Request, m.Target)
set("Micro-Method", m.Method) set(headers.Method, m.Method)
set("Micro-Endpoint", m.Endpoint) set(headers.Endpoint, m.Endpoint)
set("Micro-Error", m.Error) set(headers.Error, m.Error)
if len(stream) > 0 { if len(stream) > 0 {
set("Micro-Stream", stream) set(headers.Stream, stream)
} }
} }
@ -137,7 +143,7 @@ func setupProtocol(msg *transport.Message, node *registry.Node) codec.NewCodec {
} }
// processing topic publishing // processing topic publishing
if len(msg.Header["Micro-Topic"]) > 0 { if len(msg.Header[headers.Message]) > 0 {
return nil return nil
} }
@ -149,60 +155,59 @@ func setupProtocol(msg *transport.Message, node *registry.Node) codec.NewCodec {
msg.Header["Content-Type"] = "application/proto-rpc" msg.Header["Content-Type"] = "application/proto-rpc"
} }
// now return codec
return defaultCodecs[msg.Header["Content-Type"]] return defaultCodecs[msg.Header["Content-Type"]]
} }
func newRpcCodec(req *transport.Message, client transport.Client, c codec.NewCodec, stream string) codec.Codec { func newRPCCodec(req *transport.Message, client transport.Client, c codec.NewCodec, stream string) codec.Codec {
rwc := &readWriteCloser{ rwc := &readWriteCloser{
wbuf: bytes.NewBuffer(nil), wbuf: bytes.NewBuffer(nil),
rbuf: bytes.NewBuffer(nil), rbuf: bytes.NewBuffer(nil),
} }
r := &rpcCodec{
return &rpcCodec{
buf: rwc, buf: rwc,
client: client, client: client,
codec: c(rwc), codec: c(rwc),
req: req, req: req,
stream: stream, stream: stream,
} }
return r
} }
func (c *rpcCodec) Write(m *codec.Message, body interface{}) error { func (c *rpcCodec) Write(message *codec.Message, body interface{}) error {
c.buf.wbuf.Reset() c.buf.wbuf.Reset()
// create header // create header
if m.Header == nil { if message.Header == nil {
m.Header = map[string]string{} message.Header = map[string]string{}
} }
// copy original header // copy original header
for k, v := range c.req.Header { for k, v := range c.req.Header {
m.Header[k] = v message.Header[k] = v
} }
// set the mucp headers // set the mucp headers
setHeaders(m, c.stream) setHeaders(message, c.stream)
// if body is bytes Frame don't encode // if body is bytes Frame don't encode
if body != nil { if body != nil {
if b, ok := body.(*raw.Frame); ok { if b, ok := body.(*raw.Frame); ok {
// set body // set body
m.Body = b.Data message.Body = b.Data
} else { } else {
// write to codec // write to codec
if err := c.codec.Write(m, body); err != nil { if err := c.codec.Write(message, body); err != nil {
return errors.InternalServerError("go.micro.client.codec", err.Error()) return errors.InternalServerError("go.micro.client.codec", err.Error())
} }
// set body // set body
m.Body = c.buf.wbuf.Bytes() message.Body = c.buf.wbuf.Bytes()
} }
} }
// create new transport message // create new transport message
msg := transport.Message{ msg := transport.Message{
Header: m.Header, Header: message.Header,
Body: m.Body, Body: message.Body,
} }
// send the request // send the request
@ -213,7 +218,7 @@ func (c *rpcCodec) Write(m *codec.Message, body interface{}) error {
return nil return nil
} }
func (c *rpcCodec) ReadHeader(m *codec.Message, r codec.MessageType) error { func (c *rpcCodec) ReadHeader(msg *codec.Message, r codec.MessageType) error {
var tm transport.Message var tm transport.Message
// read message from transport // read message from transport
@ -225,13 +230,13 @@ func (c *rpcCodec) ReadHeader(m *codec.Message, r codec.MessageType) error {
c.buf.rbuf.Write(tm.Body) c.buf.rbuf.Write(tm.Body)
// set headers from transport // set headers from transport
m.Header = tm.Header msg.Header = tm.Header
// read header // read header
err := c.codec.ReadHeader(m, r) err := c.codec.ReadHeader(msg, r)
// get headers // get headers
getHeaders(m) getHeaders(msg)
// return header error // return header error
if err != nil { if err != nil {
@ -252,15 +257,23 @@ func (c *rpcCodec) ReadBody(b interface{}) error {
if err := c.codec.ReadBody(b); err != nil { if err := c.codec.ReadBody(b); err != nil {
return errors.InternalServerError("go.micro.client.codec", err.Error()) return errors.InternalServerError("go.micro.client.codec", err.Error())
} }
return nil return nil
} }
func (c *rpcCodec) Close() error { func (c *rpcCodec) Close() error {
c.buf.Close() if err := c.buf.Close(); err != nil {
c.codec.Close() return err
}
if err := c.codec.Close(); err != nil {
return err
}
if err := c.client.Close(); err != nil { if err := c.client.Close(); err != nil {
return errors.InternalServerError("go.micro.client.transport", err.Error()) return errors.InternalServerError("go.micro.client.transport", err.Error())
} }
return nil return nil
} }

View File

@ -12,8 +12,10 @@ import (
// Implements the streamer interface. // Implements the streamer interface.
type rpcStream struct { type rpcStream struct {
sync.RWMutex sync.RWMutex
id string id string
closed chan bool closed chan bool
// Indicates whether connection should be closed directly.
close bool
err error err error
request Request request Request
response Response response Response
@ -79,6 +81,7 @@ func (r *rpcStream) Recv(msg interface{}) error {
if r.isClosed() { if r.isClosed() {
r.err = errShutdown r.err = errShutdown
r.Unlock() r.Unlock()
return errShutdown return errShutdown
} }
@ -87,15 +90,19 @@ func (r *rpcStream) Recv(msg interface{}) error {
r.Unlock() r.Unlock()
err := r.codec.ReadHeader(&resp, codec.Response) err := r.codec.ReadHeader(&resp, codec.Response)
r.Lock() r.Lock()
if err != nil { if err != nil {
if err == io.EOF && !r.isClosed() { if errors.Is(err, io.EOF) && !r.isClosed() {
r.err = io.ErrUnexpectedEOF r.err = io.ErrUnexpectedEOF
r.Unlock() r.Unlock()
return io.ErrUnexpectedEOF return io.ErrUnexpectedEOF
} }
r.err = err r.err = err
r.Unlock() r.Unlock()
return err return err
} }
@ -124,13 +131,15 @@ func (r *rpcStream) Recv(msg interface{}) error {
} }
} }
r.Unlock() defer r.Unlock()
return r.err return r.err
} }
func (r *rpcStream) Error() error { func (r *rpcStream) Error() error {
r.RLock() r.RLock()
defer r.RUnlock() defer r.RUnlock()
return r.err return r.err
} }
@ -152,6 +161,7 @@ func (r *rpcStream) Close() error {
// send the end of stream message // send the end of stream message
if r.sendEOS { if r.sendEOS {
// no need to check for error // no need to check for error
//nolint:errcheck,gosec
r.codec.Write(&codec.Message{ r.codec.Write(&codec.Message{
Id: r.id, Id: r.id,
Target: r.request.Service(), Target: r.request.Service(),
@ -164,10 +174,13 @@ func (r *rpcStream) Close() error {
err := r.codec.Close() err := r.codec.Close()
rerr := r.Error()
if r.close && rerr == nil {
rerr = errors.New("connection header set to close")
}
// release the connection // release the connection
r.release(r.Error()) r.release(rerr)
// return the codec error
return err return err
} }
} }

View File

@ -10,6 +10,7 @@ import (
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"go-micro.dev/v4/codec" "go-micro.dev/v4/codec"
"go-micro.dev/v4/transport/headers"
) )
type Codec struct { type Codec struct {
@ -29,8 +30,8 @@ func (c *Codec) ReadHeader(m *codec.Message, t codec.MessageType) error {
// service method // service method
path := m.Header[":path"] path := m.Header[":path"]
if len(path) == 0 || path[0] != '/' { if len(path) == 0 || path[0] != '/' {
m.Target = m.Header["Micro-Service"] m.Target = m.Header[headers.Request]
m.Endpoint = m.Header["Micro-Endpoint"] m.Endpoint = m.Header[headers.Endpoint]
} else { } else {
// [ , a.package.Foo, Bar] // [ , a.package.Foo, Bar]
parts := strings.Split(path, "/") parts := strings.Split(path, "/")

View File

@ -3,6 +3,8 @@ package handler
import ( import (
"context" "context"
"errors"
"io"
"time" "time"
"go-micro.dev/v4/client" "go-micro.dev/v4/client"
@ -10,7 +12,6 @@ import (
proto "go-micro.dev/v4/debug/proto" proto "go-micro.dev/v4/debug/proto"
"go-micro.dev/v4/debug/stats" "go-micro.dev/v4/debug/stats"
"go-micro.dev/v4/debug/trace" "go-micro.dev/v4/debug/trace"
"go-micro.dev/v4/server"
) )
// NewHandler returns an instance of the Debug Handler. // NewHandler returns an instance of the Debug Handler.
@ -22,6 +23,8 @@ func NewHandler(c client.Client) *Debug {
} }
} }
var _ proto.DebugHandler = (*Debug)(nil)
type Debug struct { type Debug struct {
// must honor the debug handler // must honor the debug handler
proto.DebugHandler proto.DebugHandler
@ -38,6 +41,25 @@ func (d *Debug) Health(ctx context.Context, req *proto.HealthRequest, rsp *proto
return nil return nil
} }
func (d *Debug) MessageBus(ctx context.Context, stream proto.Debug_MessageBusStream) error {
for {
_, err := stream.Recv()
if errors.Is(err, io.EOF) {
return nil
} else if err != nil {
return err
}
rsp := proto.BusMsg{
Msg: "Request received!",
}
if err := stream.Send(&rsp); err != nil {
return err
}
}
}
func (d *Debug) Stats(ctx context.Context, req *proto.StatsRequest, rsp *proto.StatsResponse) error { func (d *Debug) Stats(ctx context.Context, req *proto.StatsRequest, rsp *proto.StatsResponse) error {
stats, err := d.stats.Read() stats, err := d.stats.Read()
if err != nil { if err != nil {
@ -92,11 +114,7 @@ func (d *Debug) Trace(ctx context.Context, req *proto.TraceRequest, rsp *proto.T
return nil return nil
} }
func (d *Debug) Log(ctx context.Context, stream server.Stream) error { func (d *Debug) Log(ctx context.Context, req *proto.LogRequest, stream proto.Debug_LogStream) error {
req := new(proto.LogRequest)
if err := stream.Recv(req); err != nil {
return err
}
var options []log.ReadOption var options []log.ReadOption

File diff suppressed because it is too large Load Diff

View File

@ -5,7 +5,7 @@ package debug
import ( import (
fmt "fmt" fmt "fmt"
proto "github.com/golang/protobuf/proto" proto "google.golang.org/protobuf/proto"
math "math" math "math"
) )
@ -21,12 +21,6 @@ var _ = proto.Marshal
var _ = fmt.Errorf var _ = fmt.Errorf
var _ = math.Inf var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
// Reference imports to suppress errors if they are not otherwise used. // Reference imports to suppress errors if they are not otherwise used.
var _ api.Endpoint var _ api.Endpoint
var _ context.Context var _ context.Context
@ -46,6 +40,7 @@ type DebugService interface {
Health(ctx context.Context, in *HealthRequest, opts ...client.CallOption) (*HealthResponse, error) Health(ctx context.Context, in *HealthRequest, opts ...client.CallOption) (*HealthResponse, error)
Stats(ctx context.Context, in *StatsRequest, opts ...client.CallOption) (*StatsResponse, error) Stats(ctx context.Context, in *StatsRequest, opts ...client.CallOption) (*StatsResponse, error)
Trace(ctx context.Context, in *TraceRequest, opts ...client.CallOption) (*TraceResponse, error) Trace(ctx context.Context, in *TraceRequest, opts ...client.CallOption) (*TraceResponse, error)
MessageBus(ctx context.Context, opts ...client.CallOption) (Debug_MessageBusService, error)
} }
type debugService struct { type debugService struct {
@ -76,6 +71,7 @@ type Debug_LogService interface {
Context() context.Context Context() context.Context
SendMsg(interface{}) error SendMsg(interface{}) error
RecvMsg(interface{}) error RecvMsg(interface{}) error
CloseSend() error
Close() error Close() error
Recv() (*Record, error) Recv() (*Record, error)
} }
@ -84,6 +80,10 @@ type debugServiceLog struct {
stream client.Stream stream client.Stream
} }
func (x *debugServiceLog) CloseSend() error {
return x.stream.CloseSend()
}
func (x *debugServiceLog) Close() error { func (x *debugServiceLog) Close() error {
return x.stream.Close() return x.stream.Close()
} }
@ -139,6 +139,62 @@ func (c *debugService) Trace(ctx context.Context, in *TraceRequest, opts ...clie
return out, nil return out, nil
} }
func (c *debugService) MessageBus(ctx context.Context, opts ...client.CallOption) (Debug_MessageBusService, error) {
req := c.c.NewRequest(c.name, "Debug.MessageBus", &BusMsg{})
stream, err := c.c.Stream(ctx, req, opts...)
if err != nil {
return nil, err
}
return &debugServiceMessageBus{stream}, nil
}
type Debug_MessageBusService interface {
Context() context.Context
SendMsg(interface{}) error
RecvMsg(interface{}) error
CloseSend() error
Close() error
Send(*BusMsg) error
Recv() (*BusMsg, error)
}
type debugServiceMessageBus struct {
stream client.Stream
}
func (x *debugServiceMessageBus) CloseSend() error {
return x.stream.CloseSend()
}
func (x *debugServiceMessageBus) Close() error {
return x.stream.Close()
}
func (x *debugServiceMessageBus) Context() context.Context {
return x.stream.Context()
}
func (x *debugServiceMessageBus) SendMsg(m interface{}) error {
return x.stream.Send(m)
}
func (x *debugServiceMessageBus) RecvMsg(m interface{}) error {
return x.stream.Recv(m)
}
func (x *debugServiceMessageBus) Send(m *BusMsg) error {
return x.stream.Send(m)
}
func (x *debugServiceMessageBus) Recv() (*BusMsg, error) {
m := new(BusMsg)
err := x.stream.Recv(m)
if err != nil {
return nil, err
}
return m, nil
}
// Server API for Debug service // Server API for Debug service
type DebugHandler interface { type DebugHandler interface {
@ -146,6 +202,7 @@ type DebugHandler interface {
Health(context.Context, *HealthRequest, *HealthResponse) error Health(context.Context, *HealthRequest, *HealthResponse) error
Stats(context.Context, *StatsRequest, *StatsResponse) error Stats(context.Context, *StatsRequest, *StatsResponse) error
Trace(context.Context, *TraceRequest, *TraceResponse) error Trace(context.Context, *TraceRequest, *TraceResponse) error
MessageBus(context.Context, Debug_MessageBusStream) error
} }
func RegisterDebugHandler(s server.Server, hdlr DebugHandler, opts ...server.HandlerOption) error { func RegisterDebugHandler(s server.Server, hdlr DebugHandler, opts ...server.HandlerOption) error {
@ -154,6 +211,7 @@ func RegisterDebugHandler(s server.Server, hdlr DebugHandler, opts ...server.Han
Health(ctx context.Context, in *HealthRequest, out *HealthResponse) error Health(ctx context.Context, in *HealthRequest, out *HealthResponse) error
Stats(ctx context.Context, in *StatsRequest, out *StatsResponse) error Stats(ctx context.Context, in *StatsRequest, out *StatsResponse) error
Trace(ctx context.Context, in *TraceRequest, out *TraceResponse) error Trace(ctx context.Context, in *TraceRequest, out *TraceResponse) error
MessageBus(ctx context.Context, stream server.Stream) error
} }
type Debug struct { type Debug struct {
debug debug
@ -217,3 +275,48 @@ func (h *debugHandler) Stats(ctx context.Context, in *StatsRequest, out *StatsRe
func (h *debugHandler) Trace(ctx context.Context, in *TraceRequest, out *TraceResponse) error { func (h *debugHandler) Trace(ctx context.Context, in *TraceRequest, out *TraceResponse) error {
return h.DebugHandler.Trace(ctx, in, out) return h.DebugHandler.Trace(ctx, in, out)
} }
func (h *debugHandler) MessageBus(ctx context.Context, stream server.Stream) error {
return h.DebugHandler.MessageBus(ctx, &debugMessageBusStream{stream})
}
type Debug_MessageBusStream interface {
Context() context.Context
SendMsg(interface{}) error
RecvMsg(interface{}) error
Close() error
Send(*BusMsg) error
Recv() (*BusMsg, error)
}
type debugMessageBusStream struct {
stream server.Stream
}
func (x *debugMessageBusStream) Close() error {
return x.stream.Close()
}
func (x *debugMessageBusStream) Context() context.Context {
return x.stream.Context()
}
func (x *debugMessageBusStream) SendMsg(m interface{}) error {
return x.stream.Send(m)
}
func (x *debugMessageBusStream) RecvMsg(m interface{}) error {
return x.stream.Recv(m)
}
func (x *debugMessageBusStream) Send(m *BusMsg) error {
return x.stream.Send(m)
}
func (x *debugMessageBusStream) Recv() (*BusMsg, error) {
m := new(BusMsg)
if err := x.stream.Recv(m); err != nil {
return nil, err
}
return m, nil
}

View File

@ -1,99 +1,107 @@
syntax = "proto3"; syntax = "proto3";
package debug;
option go_package = "./proto;debug";
// Compile this proto by running the following command in the debug directory:
// protoc --proto_path=. --micro_out=. --go_out=:. proto/debug.proto
service Debug { service Debug {
rpc Log(LogRequest) returns (stream Record) {}; rpc Log(LogRequest) returns (stream Record) {};
rpc Health(HealthRequest) returns (HealthResponse) {}; rpc Health(HealthRequest) returns (HealthResponse) {};
rpc Stats(StatsRequest) returns (StatsResponse) {}; rpc Stats(StatsRequest) returns (StatsResponse) {};
rpc Trace(TraceRequest) returns (TraceResponse) {}; rpc Trace(TraceRequest) returns (TraceResponse) {};
rpc MessageBus(stream BusMsg) returns (stream BusMsg) {};
} }
message BusMsg { string msg = 1; }
message HealthRequest { message HealthRequest {
// optional service name // optional service name
string service = 1; string service = 1;
} }
message HealthResponse { message HealthResponse {
// default: ok // default: ok
string status = 1; string status = 1;
} }
message StatsRequest { message StatsRequest {
// optional service name // optional service name
string service = 1; string service = 1;
} }
message StatsResponse { message StatsResponse {
// timestamp of recording // timestamp of recording
uint64 timestamp = 1; uint64 timestamp = 1;
// unix timestamp // unix timestamp
uint64 started = 2; uint64 started = 2;
// in seconds // in seconds
uint64 uptime = 3; uint64 uptime = 3;
// in bytes // in bytes
uint64 memory = 4; uint64 memory = 4;
// num threads // num threads
uint64 threads = 5; uint64 threads = 5;
// total gc in nanoseconds // total gc in nanoseconds
uint64 gc = 6; uint64 gc = 6;
// total number of requests // total number of requests
uint64 requests = 7; uint64 requests = 7;
// total number of errors // total number of errors
uint64 errors = 8; uint64 errors = 8;
} }
// LogRequest requests service logs // LogRequest requests service logs
message LogRequest { message LogRequest {
// service to request logs for // service to request logs for
string service = 1; string service = 1;
// stream records continuously // stream records continuously
bool stream = 2; bool stream = 2;
// count of records to request // count of records to request
int64 count = 3; int64 count = 3;
// relative time in seconds // relative time in seconds
// before the current time // before the current time
// from which to show logs // from which to show logs
int64 since = 4; int64 since = 4;
} }
// Record is service log record // Record is service log record
// Also used as default basic message type to test requests.
message Record { message Record {
// timestamp of log record // timestamp of log record
int64 timestamp = 1; int64 timestamp = 1;
// record metadata // record metadata
map<string,string> metadata = 2; map<string, string> metadata = 2;
// message // message
string message = 3; string message = 3;
} }
message TraceRequest { message TraceRequest {
// trace id to retrieve // trace id to retrieve
string id = 1; string id = 1;
}
message TraceResponse {
repeated Span spans = 1;
} }
message TraceResponse { repeated Span spans = 1; }
enum SpanType { enum SpanType {
INBOUND = 0; INBOUND = 0;
OUTBOUND = 1; OUTBOUND = 1;
} }
message Span { message Span {
// the trace id // the trace id
string trace = 1; string trace = 1;
// id of the span // id of the span
string id = 2; string id = 2;
// parent span // parent span
string parent = 3; string parent = 3;
// name of the resource // name of the resource
string name = 4; string name = 4;
// time of start in nanoseconds // time of start in nanoseconds
uint64 started = 5; uint64 started = 5;
// duration of the execution in nanoseconds // duration of the execution in nanoseconds
uint64 duration = 6; uint64 duration = 6;
// associated metadata // associated metadata
map<string,string> metadata = 7; map<string, string> metadata = 7;
SpanType type = 8; SpanType type = 8;
} }

21
debug/trace/noop.go Normal file
View File

@ -0,0 +1,21 @@
package trace
import "context"
type noop struct{}
func (n *noop) Init(...Option) error {
return nil
}
func (n *noop) Start(ctx context.Context, name string) (context.Context, *Span) {
return nil, nil
}
func (n *noop) Finish(*Span) error {
return nil
}
func (n *noop) Read(...ReadOption) ([]*Span, error) {
return nil, nil
}

View File

@ -6,6 +6,12 @@ import (
"time" "time"
"go-micro.dev/v4/metadata" "go-micro.dev/v4/metadata"
"go-micro.dev/v4/transport/headers"
)
var (
// DefaultTracer is the default tracer.
DefaultTracer = NewTracer()
) )
// Tracer is an interface for distributed tracing. // Tracer is an interface for distributed tracing.
@ -48,52 +54,29 @@ type Span struct {
Type SpanType Type SpanType
} }
const (
traceIDKey = "Micro-Trace-Id"
spanIDKey = "Micro-Span-Id"
)
// FromContext returns a span from context. // FromContext returns a span from context.
func FromContext(ctx context.Context) (traceID string, parentSpanID string, isFound bool) { func FromContext(ctx context.Context) (traceID string, parentSpanID string, isFound bool) {
traceID, traceOk := metadata.Get(ctx, traceIDKey) traceID, traceOk := metadata.Get(ctx, headers.TraceIDKey)
microID, microOk := metadata.Get(ctx, "Micro-Id") microID, microOk := metadata.Get(ctx, headers.ID)
if !traceOk && !microOk { if !traceOk && !microOk {
isFound = false isFound = false
return return
} }
if !traceOk { if !traceOk {
traceID = microID traceID = microID
} }
parentSpanID, ok := metadata.Get(ctx, spanIDKey)
parentSpanID, ok := metadata.Get(ctx, headers.SpanID)
return traceID, parentSpanID, ok return traceID, parentSpanID, ok
} }
// ToContext saves the trace and span ids in the context. // ToContext saves the trace and span ids in the context.
func ToContext(ctx context.Context, traceID, parentSpanID string) context.Context { func ToContext(ctx context.Context, traceID, parentSpanID string) context.Context {
return metadata.MergeContext(ctx, map[string]string{ return metadata.MergeContext(ctx, map[string]string{
traceIDKey: traceID, headers.TraceIDKey: traceID,
spanIDKey: parentSpanID, headers.SpanID: parentSpanID,
}, true) }, true)
} }
var (
DefaultTracer Tracer = NewTracer()
)
type noop struct{}
func (n *noop) Init(...Option) error {
return nil
}
func (n *noop) Start(ctx context.Context, name string) (context.Context, *Span) {
return nil, nil
}
func (n *noop) Finish(*Span) error {
return nil
}
func (n *noop) Read(...ReadOption) ([]*Span, error) {
return nil, nil
}

4
go.mod
View File

@ -1,6 +1,6 @@
module go-micro.dev/v4 module go-micro.dev/v4
go 1.17 go 1.18
require ( require (
github.com/bitly/go-simplejson v0.5.0 github.com/bitly/go-simplejson v0.5.0
@ -69,7 +69,7 @@ require (
github.com/sirupsen/logrus v1.7.0 // indirect github.com/sirupsen/logrus v1.7.0 // indirect
github.com/xanzy/ssh-agent v0.3.0 // indirect github.com/xanzy/ssh-agent v0.3.0 // indirect
go.opencensus.io v0.22.3 // indirect go.opencensus.io v0.22.3 // indirect
golang.org/x/sys v0.0.0-20210502180810-71e4cd670f79 // indirect golang.org/x/sys v0.0.0-20210510120138-977fb7262007 // indirect
golang.org/x/text v0.3.6 // indirect golang.org/x/text v0.3.6 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/warnings.v0 v0.1.2 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect

5
go.sum
View File

@ -507,7 +507,6 @@ github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1
github.com/transip/gotransip/v6 v6.2.0/go.mod h1:pQZ36hWWRahCUXkFWlx9Hs711gLd8J4qdgLdRzmtY+g= github.com/transip/gotransip/v6 v6.2.0/go.mod h1:pQZ36hWWRahCUXkFWlx9Hs711gLd8J4qdgLdRzmtY+g=
github.com/uber-go/atomic v1.3.2/go.mod h1:/Ct5t2lcmbJ4OSe/waGBoaVvVqtO0bmtfVNex1PFV8g= github.com/uber-go/atomic v1.3.2/go.mod h1:/Ct5t2lcmbJ4OSe/waGBoaVvVqtO0bmtfVNex1PFV8g=
github.com/urfave/cli v1.22.2/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/urfave/cli v1.22.2/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
github.com/urfave/cli v1.22.4 h1:u7tSpNPPswAFymm8IehJhy4uJMlUuU/GmqSkvJ1InXA=
github.com/urfave/cli v1.22.4/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/urfave/cli v1.22.4/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
github.com/urfave/cli/v2 v2.3.0 h1:qph92Y649prgesehzOrQjdWyxFOp/QVM+6imKHad91M= github.com/urfave/cli/v2 v2.3.0 h1:qph92Y649prgesehzOrQjdWyxFOp/QVM+6imKHad91M=
github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI= github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI=
@ -689,9 +688,9 @@ golang.org/x/sys v0.0.0-20210216224549-f992740a1bac/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210324051608-47abb6519492/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210324051608-47abb6519492/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210502180810-71e4cd670f79 h1:RX8C8PRZc2hTIod4ds8ij+/4RQX3AqhYj3uOHmyaz4E=
golang.org/x/sys v0.0.0-20210502180810-71e4cd670f79/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210502180810-71e4cd670f79/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007 h1:gG67DSER+11cZvqIMb8S8bt0vZtiN6xWYARwirrOSfE=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201113234701-d7a72108b828/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201113234701-d7a72108b828/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E=

View File

@ -32,6 +32,7 @@ func (l *defaultLogger) Init(opts ...Option) error {
for _, o := range opts { for _, o := range opts {
o(&l.opts) o(&l.opts)
} }
return nil return nil
} }
@ -42,6 +43,7 @@ func (l *defaultLogger) String() string {
func (l *defaultLogger) Fields(fields map[string]interface{}) Logger { func (l *defaultLogger) Fields(fields map[string]interface{}) Logger {
l.Lock() l.Lock()
nfields := make(map[string]interface{}, len(l.opts.Fields)) nfields := make(map[string]interface{}, len(l.opts.Fields))
for k, v := range l.opts.Fields { for k, v := range l.opts.Fields {
nfields[k] = v nfields[k] = v
} }
@ -65,6 +67,7 @@ func copyFields(src map[string]interface{}) map[string]interface{} {
for k, v := range src { for k, v := range src {
dst[k] = v dst[k] = v
} }
return dst return dst
} }
@ -85,10 +88,13 @@ func logCallerfilePath(loggingFilePath string) string {
if idx == -1 { if idx == -1 {
return loggingFilePath return loggingFilePath
} }
idx = strings.LastIndexByte(loggingFilePath[:idx], '/') idx = strings.LastIndexByte(loggingFilePath[:idx], '/')
if idx == -1 { if idx == -1 {
return loggingFilePath return loggingFilePath
} }
return loggingFilePath[idx+1:] return loggingFilePath[idx+1:]
} }
@ -121,6 +127,7 @@ func (l *defaultLogger) Log(level Level, v ...interface{}) {
} }
sort.Strings(keys) sort.Strings(keys)
metadata := "" metadata := ""
for _, k := range keys { for _, k := range keys {
@ -162,6 +169,7 @@ func (l *defaultLogger) Logf(level Level, format string, v ...interface{}) {
} }
sort.Strings(keys) sort.Strings(keys)
metadata := "" metadata := ""
for _, k := range keys { for _, k := range keys {
@ -177,9 +185,11 @@ func (l *defaultLogger) Logf(level Level, format string, v ...interface{}) {
func (l *defaultLogger) Options() Options { func (l *defaultLogger) Options() Options {
// not guard against options Context values // not guard against options Context values
l.RLock() l.RLock()
defer l.RUnlock()
opts := l.opts opts := l.opts
opts.Fields = copyFields(l.opts.Fields) opts.Fields = copyFields(l.opts.Fields)
l.RUnlock()
return opts return opts
} }

View File

@ -62,6 +62,7 @@ func NewEvent(topic string, c client.Client) Event {
if c == nil { if c == nil {
c = client.NewClient() c = client.NewClient()
} }
return &event{c, topic} return &event{c, topic}
} }

View File

@ -7,10 +7,11 @@ import (
"sync" "sync"
"time" "time"
"golang.org/x/sync/singleflight"
log "go-micro.dev/v4/logger" log "go-micro.dev/v4/logger"
"go-micro.dev/v4/registry" "go-micro.dev/v4/registry"
util "go-micro.dev/v4/util/registry" util "go-micro.dev/v4/util/registry"
"golang.org/x/sync/singleflight"
) )
// Cache is the registry cache interface. // Cache is the registry cache interface.
@ -464,6 +465,7 @@ func (c *cache) String() string {
// New returns a new cache. // New returns a new cache.
func New(r registry.Registry, opts ...Option) Cache { func New(r registry.Registry, opts ...Option) Cache {
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
options := Options{ options := Options{
TTL: DefaultTTL, TTL: DefaultTTL,
Logger: log.DefaultLogger, Logger: log.DefaultLogger,

View File

@ -1,8 +1,11 @@
package selector package selector
import ( import (
"sync"
"time" "time"
"github.com/pkg/errors"
"go-micro.dev/v4/registry" "go-micro.dev/v4/registry"
"go-micro.dev/v4/registry/cache" "go-micro.dev/v4/registry/cache"
) )
@ -10,19 +13,25 @@ import (
type registrySelector struct { type registrySelector struct {
so Options so Options
rc cache.Cache rc cache.Cache
mu sync.RWMutex
} }
func (c *registrySelector) newCache() cache.Cache { func (c *registrySelector) newCache() cache.Cache {
opts := make([]cache.Option, 0, 1) opts := make([]cache.Option, 0, 1)
if c.so.Context != nil { if c.so.Context != nil {
if t, ok := c.so.Context.Value("selector_ttl").(time.Duration); ok { if t, ok := c.so.Context.Value("selector_ttl").(time.Duration); ok {
opts = append(opts, cache.WithTTL(t)) opts = append(opts, cache.WithTTL(t))
} }
} }
return cache.New(c.so.Registry, opts...) return cache.New(c.so.Registry, opts...)
} }
func (c *registrySelector) Init(opts ...Option) error { func (c *registrySelector) Init(opts ...Option) error {
c.mu.Lock()
defer c.mu.Unlock()
for _, o := range opts { for _, o := range opts {
o(&c.so) o(&c.so)
} }
@ -38,6 +47,9 @@ func (c *registrySelector) Options() Options {
} }
func (c *registrySelector) Select(service string, opts ...SelectOption) (Next, error) { func (c *registrySelector) Select(service string, opts ...SelectOption) (Next, error) {
c.mu.RLock()
defer c.mu.RUnlock()
sopts := SelectOptions{ sopts := SelectOptions{
Strategy: c.so.Strategy, Strategy: c.so.Strategy,
} }
@ -51,9 +63,10 @@ func (c *registrySelector) Select(service string, opts ...SelectOption) (Next, e
// if that fails go directly to the registry // if that fails go directly to the registry
services, err := c.rc.GetService(service) services, err := c.rc.GetService(service)
if err != nil { if err != nil {
if err == registry.ErrNotFound { if errors.Is(err, registry.ErrNotFound) {
return nil, ErrNotFound return nil, ErrNotFound
} }
return nil, err return nil, err
} }
@ -87,6 +100,7 @@ func (c *registrySelector) String() string {
return "registry" return "registry"
} }
// NewSelector creates a new default selector.
func NewSelector(opts ...Option) Selector { func NewSelector(opts ...Option) Selector {
sopts := Options{ sopts := Options{
Strategy: Random, Strategy: Random,

View File

@ -6,12 +6,13 @@ import (
) )
type serverKey struct{} type serverKey struct{}
type wgKey struct{}
func wait(ctx context.Context) *sync.WaitGroup { func wait(ctx context.Context) *sync.WaitGroup {
if ctx == nil { if ctx == nil {
return nil return nil
} }
wg, ok := ctx.Value("wait").(*sync.WaitGroup) wg, ok := ctx.Value(wgKey{}).(*sync.WaitGroup)
if !ok { if !ok {
return nil return nil
} }

View File

@ -20,7 +20,7 @@ type RouterOptions struct {
type RouterOption func(o *RouterOptions) type RouterOption func(o *RouterOptions)
func newRouterOptions(opt ...RouterOption) RouterOptions { func NewRouterOptions(opt ...RouterOption) RouterOptions {
opts := RouterOptions{ opts := RouterOptions{
Logger: logger.DefaultLogger, Logger: logger.DefaultLogger,
} }
@ -74,7 +74,8 @@ type Options struct {
Context context.Context Context context.Context
} }
func newOptions(opt ...Option) Options { // NewOptions creates new server options.
func NewOptions(opt ...Option) Options {
opts := Options{ opts := Options{
Codecs: make(map[string]codec.NewCodec), Codecs: make(map[string]codec.NewCodec),
Metadata: map[string]string{}, Metadata: map[string]string{},
@ -275,7 +276,7 @@ func Wait(wg *sync.WaitGroup) Option {
if wg == nil { if wg == nil {
wg = new(sync.WaitGroup) wg = new(sync.WaitGroup)
} }
o.Context = context.WithValue(o.Context, "wait", wg) o.Context = context.WithValue(o.Context, wgKey{}, wg)
} }
} }

View File

@ -6,6 +6,7 @@ import (
"github.com/oxtoacart/bpool" "github.com/oxtoacart/bpool"
"github.com/pkg/errors" "github.com/pkg/errors"
"go-micro.dev/v4/codec" "go-micro.dev/v4/codec"
raw "go-micro.dev/v4/codec/bytes" raw "go-micro.dev/v4/codec/bytes"
"go-micro.dev/v4/codec/grpc" "go-micro.dev/v4/codec/grpc"
@ -14,6 +15,7 @@ import (
"go-micro.dev/v4/codec/proto" "go-micro.dev/v4/codec/proto"
"go-micro.dev/v4/codec/protorpc" "go-micro.dev/v4/codec/protorpc"
"go-micro.dev/v4/transport" "go-micro.dev/v4/transport"
"go-micro.dev/v4/transport/headers"
) )
type rpcCodec struct { type rpcCodec struct {
@ -36,6 +38,7 @@ type readWriteCloser struct {
} }
var ( var (
// DefaultContentType is the default codec content type.
DefaultContentType = "application/protobuf" DefaultContentType = "application/protobuf"
DefaultCodecs = map[string]codec.NewCodec{ DefaultCodecs = map[string]codec.NewCodec{
@ -65,12 +68,14 @@ var (
func (rwc *readWriteCloser) Read(p []byte) (n int, err error) { func (rwc *readWriteCloser) Read(p []byte) (n int, err error) {
rwc.RLock() rwc.RLock()
defer rwc.RUnlock() defer rwc.RUnlock()
return rwc.rbuf.Read(p) return rwc.rbuf.Read(p)
} }
func (rwc *readWriteCloser) Write(p []byte) (n int, err error) { func (rwc *readWriteCloser) Write(p []byte) (n int, err error) {
rwc.Lock() rwc.Lock()
defer rwc.Unlock() defer rwc.Unlock()
return rwc.wbuf.Write(p) return rwc.wbuf.Write(p)
} }
@ -82,6 +87,7 @@ func getHeader(hdr string, md map[string]string) string {
if hd := md[hdr]; len(hd) > 0 { if hd := md[hdr]; len(hd) > 0 {
return hd return hd
} }
return md["X-"+hdr] return md["X-"+hdr]
} }
@ -90,14 +96,15 @@ func getHeaders(m *codec.Message) {
if len(v) > 0 { if len(v) > 0 {
return v return v
} }
return m.Header[hdr] return m.Header[hdr]
} }
m.Id = set(m.Id, "Micro-Id") m.Id = set(m.Id, headers.ID)
m.Error = set(m.Error, "Micro-Error") m.Error = set(m.Error, headers.Error)
m.Endpoint = set(m.Endpoint, "Micro-Endpoint") m.Endpoint = set(m.Endpoint, headers.Endpoint)
m.Method = set(m.Method, "Micro-Method") m.Method = set(m.Method, headers.Method)
m.Target = set(m.Target, "Micro-Service") m.Target = set(m.Target, headers.Request)
// TODO: remove this cruft // TODO: remove this cruft
if len(m.Endpoint) == 0 { if len(m.Endpoint) == 0 {
@ -110,26 +117,27 @@ func setHeaders(m, r *codec.Message) {
if len(v) == 0 { if len(v) == 0 {
return return
} }
m.Header[hdr] = v m.Header[hdr] = v
m.Header["X-"+hdr] = v m.Header["X-"+hdr] = v
} }
// set headers // set headers
set("Micro-Id", r.Id) set(headers.ID, r.Id)
set("Micro-Service", r.Target) set(headers.Request, r.Target)
set("Micro-Method", r.Method) set(headers.Method, r.Method)
set("Micro-Endpoint", r.Endpoint) set(headers.Endpoint, r.Endpoint)
set("Micro-Error", r.Error) set(headers.Error, r.Error)
} }
// setupProtocol sets up the old protocol. // setupProtocol sets up the old protocol.
func setupProtocol(msg *transport.Message) codec.NewCodec { func setupProtocol(msg *transport.Message) codec.NewCodec {
service := getHeader("Micro-Service", msg.Header) service := getHeader(headers.Request, msg.Header)
method := getHeader("Micro-Method", msg.Header) method := getHeader(headers.Method, msg.Header)
endpoint := getHeader("Micro-Endpoint", msg.Header) endpoint := getHeader(headers.Endpoint, msg.Header)
protocol := getHeader("Micro-Protocol", msg.Header) protocol := getHeader(headers.Protocol, msg.Header)
target := getHeader("Micro-Target", msg.Header) target := getHeader(headers.Target, msg.Header)
topic := getHeader("Micro-Topic", msg.Header) topic := getHeader(headers.Message, msg.Header)
// if the protocol exists (mucp) do nothing // if the protocol exists (mucp) do nothing
if len(protocol) > 0 { if len(protocol) > 0 {
@ -153,18 +161,18 @@ func setupProtocol(msg *transport.Message) codec.NewCodec {
// no method then set to endpoint // no method then set to endpoint
if len(method) == 0 { if len(method) == 0 {
msg.Header["Micro-Method"] = endpoint msg.Header[headers.Method] = endpoint
} }
// no endpoint then set to method // no endpoint then set to method
if len(endpoint) == 0 { if len(endpoint) == 0 {
msg.Header["Micro-Endpoint"] = method msg.Header[headers.Endpoint] = method
} }
return nil return nil
} }
func newRpcCodec(req *transport.Message, socket transport.Socket, c codec.NewCodec) codec.Codec { func newRPCCodec(req *transport.Message, socket transport.Socket, c codec.NewCodec) codec.Codec {
rwc := &readWriteCloser{ rwc := &readWriteCloser{
rbuf: bufferPool.Get(), rbuf: bufferPool.Get(),
wbuf: bufferPool.Get(), wbuf: bufferPool.Get(),
@ -185,7 +193,6 @@ func newRpcCodec(req *transport.Message, socket transport.Socket, c codec.NewCod
case "grpc": case "grpc":
// write the body // write the body
rwc.rbuf.Write(req.Body) rwc.rbuf.Write(req.Body)
// set the protocol
r.protocol = "grpc" r.protocol = "grpc"
default: default:
// first is not preloaded // first is not preloaded
@ -197,7 +204,7 @@ func newRpcCodec(req *transport.Message, socket transport.Socket, c codec.NewCod
func (c *rpcCodec) ReadHeader(r *codec.Message, t codec.MessageType) error { func (c *rpcCodec) ReadHeader(r *codec.Message, t codec.MessageType) error {
// the initial message // the initial message
m := codec.Message{ mmsg := codec.Message{
Header: c.req.Header, Header: c.req.Header,
Body: c.req.Body, Body: c.req.Body,
} }
@ -221,9 +228,9 @@ func (c *rpcCodec) ReadHeader(r *codec.Message, t codec.MessageType) error {
} }
// set the message header // set the message header
m.Header = tm.Header mmsg.Header = tm.Header
// set the message body // set the message body
m.Body = tm.Body mmsg.Body = tm.Body
// set req // set req
c.req = &tm c.req = &tm
@ -248,20 +255,20 @@ func (c *rpcCodec) ReadHeader(r *codec.Message, t codec.MessageType) error {
} }
// set some internal things // set some internal things
getHeaders(&m) getHeaders(&mmsg)
// read header via codec // read header via codec
if err := c.codec.ReadHeader(&m, codec.Request); err != nil { if err := c.codec.ReadHeader(&mmsg, codec.Request); err != nil {
return err return err
} }
// fallback for 0.14 and older // fallback for 0.14 and older
if len(m.Endpoint) == 0 { if len(mmsg.Endpoint) == 0 {
m.Endpoint = m.Method mmsg.Endpoint = mmsg.Method
} }
// set message // set message
*r = m *r = mmsg
return nil return nil
} }
@ -315,7 +322,7 @@ func (c *rpcCodec) Write(r *codec.Message, b interface{}) error {
// write an error if it failed // write an error if it failed
m.Error = errors.Wrapf(err, "Unable to encode body").Error() m.Error = errors.Wrapf(err, "Unable to encode body").Error()
m.Header["Micro-Error"] = m.Error m.Header[headers.Error] = m.Error
// no body to write // no body to write
if err := c.codec.Write(m, nil); err != nil { if err := c.codec.Write(m, nil); err != nil {
return err return err

View File

@ -3,6 +3,7 @@ package server
import ( import (
"go-micro.dev/v4/broker" "go-micro.dev/v4/broker"
"go-micro.dev/v4/transport" "go-micro.dev/v4/transport"
"go-micro.dev/v4/transport/headers"
) )
// event is a broker event we handle on the server transport. // event is a broker event we handle on the server transport.
@ -25,7 +26,7 @@ func (e *event) Error() error {
} }
func (e *event) Topic() string { func (e *event) Topic() string {
return e.message.Header["Micro-Topic"] return e.message.Header[headers.Message]
} }
func newEvent(msg transport.Message) *event { func newEvent(msg transport.Message) *event {

143
server/rpc_events.go Normal file
View File

@ -0,0 +1,143 @@
package server
import (
"context"
"fmt"
"go-micro.dev/v4/broker"
raw "go-micro.dev/v4/codec/bytes"
log "go-micro.dev/v4/logger"
"go-micro.dev/v4/metadata"
"go-micro.dev/v4/transport/headers"
)
// HandleEvent handles inbound messages to the service directly.
// These events are a result of registering to the topic with the service name.
// TODO: handle requests from an event. We won't send a response.
func (s *rpcServer) HandleEvent(e broker.Event) error {
// formatting horrible cruft
msg := e.Message()
if msg.Header == nil {
msg.Header = make(map[string]string)
}
contentType, ok := msg.Header["Content-Type"]
if !ok || len(contentType) == 0 {
msg.Header["Content-Type"] = DefaultContentType
contentType = DefaultContentType
}
cf, err := s.newCodec(contentType)
if err != nil {
return err
}
header := make(map[string]string, len(msg.Header))
for k, v := range msg.Header {
header[k] = v
}
// create context
ctx := metadata.NewContext(context.Background(), header)
// TODO: inspect message header for Micro-Service & Micro-Topic
rpcMsg := &rpcMessage{
topic: msg.Header[headers.Message],
contentType: contentType,
payload: &raw.Frame{Data: msg.Body},
codec: cf,
header: msg.Header,
body: msg.Body,
}
// if the router is present then execute it
r := Router(s.router)
if s.opts.Router != nil {
// create a wrapped function
handler := s.opts.Router.ProcessMessage
// execute the wrapper for it
for i := len(s.opts.SubWrappers); i > 0; i-- {
handler = s.opts.SubWrappers[i-1](handler)
}
// set the router
r = rpcRouter{m: handler}
}
return r.ProcessMessage(ctx, rpcMsg)
}
func (s *rpcServer) NewSubscriber(topic string, sb interface{}, opts ...SubscriberOption) Subscriber {
return s.router.NewSubscriber(topic, sb, opts...)
}
func (s *rpcServer) Subscribe(sb Subscriber) error {
s.Lock()
defer s.Unlock()
sub, ok := sb.(*subscriber)
if !ok {
return fmt.Errorf("invalid subscriber: expected *subscriber")
}
if len(sub.handlers) == 0 {
return fmt.Errorf("invalid subscriber: no handler functions")
}
if err := validateSubscriber(sub); err != nil {
return err
}
// append to subscribers
// subs := s.subscribers[sub.Topic()]
// subs = append(subs, sub)
// router.subscribers[sub.Topic()] = subs
s.subscribers[sb] = nil
return nil
}
// subscribeServer will subscribe the server to the topic with its own name.
func (s *rpcServer) subscribeServer(config Options) error {
if s.opts.Router != nil {
sub, err := s.opts.Broker.Subscribe(config.Name, s.HandleEvent)
if err != nil {
return err
}
// Save the subscriber
s.subscriber = sub
}
return nil
}
// reSubscribe itterates over subscribers and re-subscribes then.
func (s *rpcServer) reSubscribe(config Options) error {
for sb := range s.subscribers {
var opts []broker.SubscribeOption
if queue := sb.Options().Queue; len(queue) > 0 {
opts = append(opts, broker.Queue(queue))
}
if ctx := sb.Options().Context; ctx != nil {
opts = append(opts, broker.SubscribeContext(ctx))
}
if !sb.Options().AutoAck {
opts = append(opts, broker.DisableAutoAck())
}
config.Logger.Logf(log.InfoLevel, "Subscribing to topic: %s", sb.Topic())
sub, err := config.Broker.Subscribe(sb.Topic(), s.HandleEvent, opts...)
if err != nil {
return err
}
s.subscribers[sb] = []broker.Subscriber{sub}
}
return nil
}

View File

@ -6,14 +6,14 @@ import (
"go-micro.dev/v4/registry" "go-micro.dev/v4/registry"
) )
type rpcHandler struct { type RpcHandler struct {
name string name string
handler interface{} handler interface{}
endpoints []*registry.Endpoint endpoints []*registry.Endpoint
opts HandlerOptions opts HandlerOptions
} }
func newRpcHandler(handler interface{}, opts ...HandlerOption) Handler { func NewRpcHandler(handler interface{}, opts ...HandlerOption) Handler {
options := HandlerOptions{ options := HandlerOptions{
Metadata: make(map[string]map[string]string), Metadata: make(map[string]map[string]string),
} }
@ -40,7 +40,7 @@ func newRpcHandler(handler interface{}, opts ...HandlerOption) Handler {
} }
} }
return &rpcHandler{ return &RpcHandler{
name: name, name: name,
handler: handler, handler: handler,
endpoints: endpoints, endpoints: endpoints,
@ -48,18 +48,18 @@ func newRpcHandler(handler interface{}, opts ...HandlerOption) Handler {
} }
} }
func (r *rpcHandler) Name() string { func (r *RpcHandler) Name() string {
return r.name return r.name
} }
func (r *rpcHandler) Handler() interface{} { func (r *RpcHandler) Handler() interface{} {
return r.handler return r.handler
} }
func (r *rpcHandler) Endpoints() []*registry.Endpoint { func (r *RpcHandler) Endpoints() []*registry.Endpoint {
return r.endpoints return r.endpoints
} }
func (r *rpcHandler) Options() HandlerOptions { func (r *RpcHandler) Options() HandlerOptions {
return r.opts return r.opts
} }

101
server/rpc_helper.go Normal file
View File

@ -0,0 +1,101 @@
package server
import (
"fmt"
"sync"
"go-micro.dev/v4/codec"
"go-micro.dev/v4/registry"
)
// setRegistered will set the service as registered safely.
func (s *rpcServer) setRegistered(b bool) {
s.Lock()
defer s.Unlock()
s.registered = b
}
// isRegistered will check if the service has already been registered.
func (s *rpcServer) isRegistered() bool {
s.RLock()
defer s.RUnlock()
return s.registered
}
// setStarted will set started state safely.
func (s *rpcServer) setStarted(b bool) {
s.Lock()
defer s.Unlock()
s.started = b
}
// isStarted will check if the service has already been started.
func (s *rpcServer) isStarted() bool {
s.RLock()
defer s.RUnlock()
return s.started
}
// setWg will set the waitgroup safely.
func (s *rpcServer) setWg(wg *sync.WaitGroup) {
s.Lock()
defer s.Unlock()
s.wg = wg
}
// getWaitgroup returns the global waitgroup safely.
func (s *rpcServer) getWg() *sync.WaitGroup {
s.RLock()
defer s.RUnlock()
return s.wg
}
// setOptsAddr will set the address in the service options safely.
func (s *rpcServer) setOptsAddr(addr string) {
s.Lock()
defer s.Unlock()
s.opts.Address = addr
}
func (s *rpcServer) getCachedService() *registry.Service {
s.RLock()
defer s.RUnlock()
return s.rsvc
}
func (s *rpcServer) Options() Options {
s.RLock()
defer s.RUnlock()
return s.opts
}
// swapAddr swaps the address found in the config and the transport address.
func (s *rpcServer) swapAddr(config Options, addr string) string {
s.Lock()
defer s.Unlock()
a := config.Address
s.opts.Address = addr
return a
}
func (s *rpcServer) newCodec(contentType string) (codec.NewCodec, error) {
if cf, ok := s.opts.Codecs[contentType]; ok {
return cf, nil
}
if cf, ok := DefaultCodecs[contentType]; ok {
return cf, nil
}
return nil, fmt.Errorf("unsupported Content-Type: %s", contentType)
}

View File

@ -1,11 +1,5 @@
package server package server
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
// Meh, we need to get rid of this shit
import ( import (
"context" "context"
"errors" "errors"
@ -80,7 +74,7 @@ type router struct {
subscribers map[string][]*subscriber subscribers map[string][]*subscriber
} }
// rpcRouter encapsulates functions that become a server.Router. // rpcRouter encapsulates functions that become a Router.
type rpcRouter struct { type rpcRouter struct {
h func(context.Context, Request, interface{}) error h func(context.Context, Request, interface{}) error
m func(context.Context, Message) error m func(context.Context, Message) error
@ -96,7 +90,7 @@ func (r rpcRouter) ServeRequest(ctx context.Context, req Request, rsp Response)
func newRpcRouter(opts ...RouterOption) *router { func newRpcRouter(opts ...RouterOption) *router {
return &router{ return &router{
ops: newRouterOptions(opts...), ops: NewRouterOptions(opts...),
serviceMap: make(map[string]*service), serviceMap: make(map[string]*service),
subscribers: make(map[string][]*subscriber), subscribers: make(map[string][]*subscriber),
} }
@ -180,11 +174,13 @@ func prepareMethod(method reflect.Method, logger log.Logger) *methodType {
logger.Logf(log.ErrorLevel, "method %v has wrong number of outs: %v", mname, mtype.NumOut()) logger.Logf(log.ErrorLevel, "method %v has wrong number of outs: %v", mname, mtype.NumOut())
return nil return nil
} }
// The return type of the method must be error. // The return type of the method must be error.
if returnType := mtype.Out(0); returnType != typeOfError { if returnType := mtype.Out(0); returnType != typeOfError {
logger.Logf(log.ErrorLevel, "method %v returns %v not error", mname, returnType.String()) logger.Logf(log.ErrorLevel, "method %v returns %v not error", mname, returnType.String())
return nil return nil
} }
return &methodType{method: method, ArgType: argType, ReplyType: replyType, ContextType: contextType, stream: stream} return &methodType{method: method, ArgType: argType, ReplyType: replyType, ContextType: contextType, stream: stream}
} }
@ -195,10 +191,13 @@ func (router *router) sendResponse(sending sync.Locker, req *request, reply inte
resp.msg = msg resp.msg = msg
resp.msg.Id = req.msg.Id resp.msg.Id = req.msg.Id
sending.Lock() sending.Lock()
err := cc.Write(resp.msg, reply) err := cc.Write(resp.msg, reply)
sending.Unlock() sending.Unlock()
router.freeResponse(resp) router.freeResponse(resp)
return err return err
} }
@ -261,6 +260,7 @@ func (s *service) call(ctx context.Context, router *router, sending *sync.Mutex,
// Invoke the method, providing a new value for the reply. // Invoke the method, providing a new value for the reply.
fn := func(ctx context.Context, req Request, stream interface{}) error { fn := func(ctx context.Context, req Request, stream interface{}) error {
returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(stream)}) returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(stream)})
if err := returnValues[0].Interface(); err != nil { if err := returnValues[0].Interface(); err != nil {
// the function returned an error, we use that // the function returned an error, we use that
return err.(error) return err.(error)
@ -288,11 +288,14 @@ func (m *methodType) prepareContext(ctx context.Context) reflect.Value {
if contextv := reflect.ValueOf(ctx); contextv.IsValid() { if contextv := reflect.ValueOf(ctx); contextv.IsValid() {
return contextv return contextv
} }
return reflect.Zero(m.ContextType) return reflect.Zero(m.ContextType)
} }
func (router *router) getRequest() *request { func (router *router) getRequest() *request {
router.reqLock.Lock() router.reqLock.Lock()
defer router.reqLock.Unlock()
req := router.freeReq req := router.freeReq
if req == nil { if req == nil {
req = new(request) req = new(request)
@ -300,19 +303,22 @@ func (router *router) getRequest() *request {
router.freeReq = req.next router.freeReq = req.next
*req = request{} *req = request{}
} }
router.reqLock.Unlock()
return req return req
} }
func (router *router) freeRequest(req *request) { func (router *router) freeRequest(req *request) {
router.reqLock.Lock() router.reqLock.Lock()
defer router.reqLock.Unlock()
req.next = router.freeReq req.next = router.freeReq
router.freeReq = req router.freeReq = req
router.reqLock.Unlock()
} }
func (router *router) getResponse() *response { func (router *router) getResponse() *response {
router.respLock.Lock() router.respLock.Lock()
defer router.respLock.Unlock()
resp := router.freeResp resp := router.freeResp
if resp == nil { if resp == nil {
resp = new(response) resp = new(response)
@ -320,15 +326,16 @@ func (router *router) getResponse() *response {
router.freeResp = resp.next router.freeResp = resp.next
*resp = response{} *resp = response{}
} }
router.respLock.Unlock()
return resp return resp
} }
func (router *router) freeResponse(resp *response) { func (router *router) freeResponse(resp *response) {
router.respLock.Lock() router.respLock.Lock()
defer router.respLock.Unlock()
resp.next = router.freeResp resp.next = router.freeResp
router.freeResp = resp router.freeResp = resp
router.respLock.Unlock()
} }
func (router *router) readRequest(r Request) (service *service, mtype *methodType, req *request, argv, replyv reflect.Value, keepReading bool, err error) { func (router *router) readRequest(r Request) (service *service, mtype *methodType, req *request, argv, replyv reflect.Value, keepReading bool, err error) {
@ -341,8 +348,10 @@ func (router *router) readRequest(r Request) (service *service, mtype *methodTyp
} }
// discard body // discard body
cc.ReadBody(nil) cc.ReadBody(nil)
return return
} }
// is it a streaming request? then we don't read the body // is it a streaming request? then we don't read the body
if mtype.stream { if mtype.stream {
if cc.(codec.Codec).String() != "grpc" { if cc.(codec.Codec).String() != "grpc" {
@ -359,10 +368,12 @@ func (router *router) readRequest(r Request) (service *service, mtype *methodTyp
argv = reflect.New(mtype.ArgType) argv = reflect.New(mtype.ArgType)
argIsValue = true argIsValue = true
} }
// argv guaranteed to be a pointer now. // argv guaranteed to be a pointer now.
if err = cc.ReadBody(argv.Interface()); err != nil { if err = cc.ReadBody(argv.Interface()); err != nil {
return return
} }
if argIsValue { if argIsValue {
argv = argv.Elem() argv = argv.Elem()
} }
@ -370,6 +381,7 @@ func (router *router) readRequest(r Request) (service *service, mtype *methodTyp
if !mtype.stream { if !mtype.stream {
replyv = reflect.New(mtype.ReplyType.Elem()) replyv = reflect.New(mtype.ReplyType.Elem())
} }
return return
} }
@ -387,6 +399,7 @@ func (router *router) readHeader(cc codec.Reader) (service *service, mtype *meth
return return
} }
err = errors.New("rpc: router cannot decode request: " + err.Error()) err = errors.New("rpc: router cannot decode request: " + err.Error())
return return
} }
@ -399,28 +412,33 @@ func (router *router) readHeader(cc codec.Reader) (service *service, mtype *meth
err = errors.New("rpc: service/endpoint request ill-formed: " + req.msg.Endpoint) err = errors.New("rpc: service/endpoint request ill-formed: " + req.msg.Endpoint)
return return
} }
// Look up the request. // Look up the request.
router.mu.Lock() router.mu.Lock()
service = router.serviceMap[serviceMethod[0]] service = router.serviceMap[serviceMethod[0]]
router.mu.Unlock() router.mu.Unlock()
if service == nil { if service == nil {
err = errors.New("rpc: can't find service " + serviceMethod[0]) err = errors.New("rpc: can't find service " + serviceMethod[0])
return return
} }
mtype = service.method[serviceMethod[1]] mtype = service.method[serviceMethod[1]]
if mtype == nil { if mtype == nil {
err = errors.New("rpc: can't find method " + serviceMethod[1]) err = errors.New("rpc: can't find method " + serviceMethod[1])
} }
return return
} }
func (router *router) NewHandler(h interface{}, opts ...HandlerOption) Handler { func (router *router) NewHandler(h interface{}, opts ...HandlerOption) Handler {
return newRpcHandler(h, opts...) return NewRpcHandler(h, opts...)
} }
func (router *router) Handle(h Handler) error { func (router *router) Handle(h Handler) error {
router.mu.Lock() router.mu.Lock()
defer router.mu.Unlock() defer router.mu.Unlock()
if router.serviceMap == nil { if router.serviceMap == nil {
router.serviceMap = make(map[string]*service) router.serviceMap = make(map[string]*service)
} }
@ -428,6 +446,7 @@ func (router *router) Handle(h Handler) error {
if len(h.Name()) == 0 { if len(h.Name()) == 0 {
return errors.New("rpc.Handle: handler has no name") return errors.New("rpc.Handle: handler has no name")
} }
if !isExported(h.Name()) { if !isExported(h.Name()) {
return errors.New("rpc.Handle: type " + h.Name() + " is not exported") return errors.New("rpc.Handle: type " + h.Name() + " is not exported")
} }
@ -460,6 +479,7 @@ func (router *router) Handle(h Handler) error {
// save handler // save handler
router.serviceMap[s.name] = s router.serviceMap[s.name] = s
return nil return nil
} }
@ -474,8 +494,10 @@ func (router *router) ServeRequest(ctx context.Context, r Request, rsp Response)
if req != nil { if req != nil {
router.freeRequest(req) router.freeRequest(req)
} }
return err return err
} }
return service.call(ctx, router, sending, mtype, req, argv, replyv, rsp.Codec()) return service.call(ctx, router, sending, mtype, req, argv, replyv, rsp.Codec())
} }
@ -488,6 +510,7 @@ func (router *router) Subscribe(s Subscriber) error {
if !ok { if !ok {
return fmt.Errorf("invalid subscriber: expected *subscriber") return fmt.Errorf("invalid subscriber: expected *subscriber")
} }
if len(sub.handlers) == 0 { if len(sub.handlers) == 0 {
return fmt.Errorf("invalid subscriber: no handler functions") return fmt.Errorf("invalid subscriber: no handler functions")
} }
@ -517,10 +540,9 @@ func (router *router) ProcessMessage(ctx context.Context, msg Message) (err erro
} }
}() }()
router.su.RLock()
// get the subscribers by topic // get the subscribers by topic
router.su.RLock()
subs, ok := router.subscribers[msg.Topic()] subs, ok := router.subscribers[msg.Topic()]
// unlock since we only need to get the subs
router.su.RUnlock() router.su.RUnlock()
if !ok { if !ok {
return nil return nil

File diff suppressed because it is too large Load Diff

View File

@ -12,6 +12,13 @@ type waitGroup struct {
gg *sync.WaitGroup gg *sync.WaitGroup
} }
// NewWaitGroup returns a new double waitgroup for global management of processes.
func NewWaitGroup(gWg *sync.WaitGroup) *waitGroup {
return &waitGroup{
gg: gWg,
}
}
func (w *waitGroup) Add(i int) { func (w *waitGroup) Add(i int) {
w.lg.Add(i) w.lg.Add(i)
if w.gg != nil { if w.gg != nil {

View File

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"go-micro.dev/v4/codec" "go-micro.dev/v4/codec"
log "go-micro.dev/v4/logger" log "go-micro.dev/v4/logger"
"go-micro.dev/v4/registry" "go-micro.dev/v4/registry"
@ -140,14 +141,14 @@ var (
DefaultName = "go.micro.server" DefaultName = "go.micro.server"
DefaultVersion = "latest" DefaultVersion = "latest"
DefaultId = uuid.New().String() DefaultId = uuid.New().String()
DefaultServer Server = newRpcServer() DefaultServer Server = NewRPCServer()
DefaultRouter = newRpcRouter() DefaultRouter = newRpcRouter()
DefaultRegisterCheck = func(context.Context) error { return nil } DefaultRegisterCheck = func(context.Context) error { return nil }
DefaultRegisterInterval = time.Second * 30 DefaultRegisterInterval = time.Second * 30
DefaultRegisterTTL = time.Second * 90 DefaultRegisterTTL = time.Second * 90
// NewServer creates a new server. // NewServer creates a new server.
NewServer func(...Option) Server = newRpcServer NewServer func(...Option) Server = NewRPCServer
) )
// DefaultOptions returns config options for the default service. // DefaultOptions returns config options for the default service.
@ -157,7 +158,7 @@ func DefaultOptions() Options {
func Init(opt ...Option) { func Init(opt ...Option) {
if DefaultServer == nil { if DefaultServer == nil {
DefaultServer = newRpcServer(opt...) DefaultServer = NewRPCServer(opt...)
} }
DefaultServer.Init(opt...) DefaultServer.Init(opt...)
} }

View File

@ -113,7 +113,7 @@ func (s *service) Stop() error {
err = fn() err = fn()
} }
if err = s.opts.Server.Stop(); err != nil { if err := s.opts.Server.Stop(); err != nil {
return err return err
} }
@ -144,6 +144,7 @@ func (s *service) Run() (err error) {
if err = s.opts.Profile.Start(); err != nil { if err = s.opts.Profile.Start(); err != nil {
return err return err
} }
defer func() { defer func() {
err = s.opts.Profile.Stop() err = s.opts.Profile.Stop()
if err != nil { if err != nil {

View File

@ -1,286 +0,0 @@
package micro
import (
"context"
"errors"
"net"
"sync"
"testing"
"go-micro.dev/v4/client"
"go-micro.dev/v4/debug/handler"
proto "go-micro.dev/v4/debug/proto"
"go-micro.dev/v4/registry"
"go-micro.dev/v4/server"
"go-micro.dev/v4/transport"
"go-micro.dev/v4/util/test"
)
func testShutdown(wg *sync.WaitGroup, cancel func()) {
// add 1
wg.Add(1)
// shutdown the service
cancel()
// wait for stop
wg.Wait()
}
func testService(t testing.TB, ctx context.Context, wg *sync.WaitGroup, name string) Service {
// add self
wg.Add(1)
r := registry.NewMemoryRegistry(registry.Services(test.Data))
// create service
srv := NewService(
Name(name),
Context(ctx),
Registry(r),
AfterStart(func() error {
wg.Done()
return nil
}),
AfterStop(func() error {
wg.Done()
return nil
}),
)
if err := RegisterHandler(srv.Server(), handler.NewHandler(srv.Client())); err != nil {
t.Fatal(err)
}
return srv
}
func testCustomListenService(ctx context.Context, customListener net.Listener, wg *sync.WaitGroup, name string) Service {
// add self
wg.Add(1)
r := registry.NewMemoryRegistry(registry.Services(test.Data))
// create service
srv := NewService(
Name(name),
Context(ctx),
Registry(r),
// injection customListener
AddListenOption(server.ListenOption(transport.NetListener(customListener))),
AfterStart(func() error {
wg.Done()
return nil
}),
AfterStop(func() error {
wg.Done()
return nil
}),
)
RegisterHandler(srv.Server(), handler.NewHandler(srv.Client()))
return srv
}
func testRequest(ctx context.Context, c client.Client, name string) error {
// test call debug
req := c.NewRequest(
name,
"Debug.Health",
new(proto.HealthRequest),
)
rsp := new(proto.HealthResponse)
err := c.Call(context.TODO(), req, rsp)
if err != nil {
return err
}
if rsp.Status != "ok" {
return errors.New("service response: " + rsp.Status)
}
return nil
}
// TestService tests running and calling a service.
func TestService(t *testing.T) {
// waitgroup for server start
var wg sync.WaitGroup
// cancellation context
ctx, cancel := context.WithCancel(context.Background())
// start test server
service := testService(t, ctx, &wg, "test.service")
go func() {
// wait for service to start
wg.Wait()
// make a test call
if err := testRequest(ctx, service.Client(), "test.service"); err != nil {
t.Fatal(err)
}
// shutdown the service
testShutdown(&wg, cancel)
}()
// start service
if err := service.Run(); err != nil {
t.Fatal(err)
}
}
func benchmarkCustomListenService(b *testing.B, n int, name string) {
// create custom listen
customListen, err := net.Listen("tcp", server.DefaultAddress)
if err != nil {
b.Fatal(err)
}
// stop the timer
b.StopTimer()
// waitgroup for server start
var wg sync.WaitGroup
// cancellation context
ctx, cancel := context.WithCancel(context.Background())
// create test server
service := testCustomListenService(ctx, customListen, &wg, name)
// start the server
go func() {
if err := service.Run(); err != nil {
b.Fatal(err)
}
}()
// wait for service to start
wg.Wait()
// make a test call to warm the cache
for i := 0; i < 10; i++ {
if err := testRequest(ctx, service.Client(), name); err != nil {
b.Fatal(err)
}
}
// start the timer
b.StartTimer()
// number of iterations
for i := 0; i < b.N; i++ {
// for concurrency
for j := 0; j < n; j++ {
wg.Add(1)
go func() {
err := testRequest(ctx, service.Client(), name)
wg.Done()
if err != nil {
b.Fatal(err)
}
}()
}
// wait for test completion
wg.Wait()
}
// stop the timer
b.StopTimer()
// shutdown service
testShutdown(&wg, cancel)
}
func benchmarkService(b *testing.B, n int, name string) {
// stop the timer
b.StopTimer()
// waitgroup for server start
var wg sync.WaitGroup
// cancellation context
ctx, cancel := context.WithCancel(context.Background())
// create test server
service := testService(b, ctx, &wg, name)
// start the server
go func() {
if err := service.Run(); err != nil {
b.Fatal(err)
}
}()
// wait for service to start
wg.Wait()
// make a test call to warm the cache
for i := 0; i < 10; i++ {
if err := testRequest(ctx, service.Client(), name); err != nil {
b.Fatal(err)
}
}
// start the timer
b.StartTimer()
// number of iterations
for i := 0; i < b.N; i++ {
// for concurrency
for j := 0; j < n; j++ {
wg.Add(1)
go func() {
err := testRequest(ctx, service.Client(), name)
wg.Done()
if err != nil {
b.Fatal(err)
}
}()
}
// wait for test completion
wg.Wait()
}
// stop the timer
b.StopTimer()
// shutdown service
testShutdown(&wg, cancel)
}
func BenchmarkService1(b *testing.B) {
benchmarkService(b, 1, "test.service.1")
}
func BenchmarkService8(b *testing.B) {
benchmarkService(b, 8, "test.service.8")
}
func BenchmarkService16(b *testing.B) {
benchmarkService(b, 16, "test.service.16")
}
func BenchmarkService32(b *testing.B) {
benchmarkService(b, 32, "test.service.32")
}
func BenchmarkService64(b *testing.B) {
benchmarkService(b, 64, "test.service.64")
}
func BenchmarkCustomListenService1(b *testing.B) {
benchmarkCustomListenService(b, 1, "test.service.1")
}

64
tests/default_test.go Normal file
View File

@ -0,0 +1,64 @@
package tests
import (
"testing"
"go-micro.dev/v4"
"go-micro.dev/v4/broker"
"go-micro.dev/v4/client"
"go-micro.dev/v4/registry"
"go-micro.dev/v4/server"
"go-micro.dev/v4/transport"
"go-micro.dev/v4/util/test"
)
func BenchmarkService(b *testing.B) {
cfg := ServiceTestConfig{
Name: "test-service",
NewService: newService,
Parallel: []int{1, 8, 16, 32, 64},
Sequential: []int{0},
Streams: []int{0},
// PubSub: []int{10},
}
cfg.Run(b)
}
func newService(name string, opts ...micro.Option) (micro.Service, error) {
r := registry.NewMemoryRegistry(
registry.Services(test.Data),
)
b := broker.NewMemoryBroker()
t := transport.NewHTTPTransport()
c := client.NewClient(
client.Transport(t),
client.Broker(b),
)
s := server.NewRPCServer(
server.Name(name),
server.Registry(r),
server.Transport(t),
server.Broker(b),
)
if err := s.Init(); err != nil {
return nil, err
}
options := []micro.Option{
micro.Name(name),
micro.Server(s),
micro.Client(c),
micro.Registry(r),
micro.Broker(b),
}
options = append(options, opts...)
srv := micro.NewService(options...)
return srv, nil
}

395
tests/service.go Normal file
View File

@ -0,0 +1,395 @@
// Package tests implements a testing framwork, and provides default tests.
package tests
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/pkg/errors"
proto "github.com/go-micro/plugins/v4/server/grpc/proto"
"go-micro.dev/v4"
"go-micro.dev/v4/client"
"go-micro.dev/v4/debug/handler"
pb "go-micro.dev/v4/debug/proto"
)
var (
// ErrNoTests returns no test params are set.
ErrNoTests = errors.New("No tests to run, all values set to 0")
testTopic = "Test-Topic"
errorTopic = "Error-Topic"
)
type parTest func(name string, c client.Client, p, s int, errChan chan error)
type testFunc func(name string, c client.Client, errChan chan error)
// ServiceTestConfig allows you to easily test a service configuration by
// running predefined tests against your custom service. You only need to
// provide a function to create the service, and how many of which test you
// want to run.
//
// The default tests provided, all running with separate parallel routines are:
// - Sequential Call requests
// - Bi-directional streaming
// - Pub/Sub events brokering
//
// You can provide an array of parallel routines to run for the request and
// stream tests. They will be run as matrix tests, so with each possible combination.
// Thus, in total (p * seq) + (p * streams) tests will be run.
type ServiceTestConfig struct {
// Service name to use for the tests
Name string
// NewService function will be called to setup the new service.
// It takes in a list of options, which by default will Context and an
// AfterStart with channel to signal when the service has been started.
NewService func(name string, opts ...micro.Option) (micro.Service, error)
// Parallel is the number of prallell routines to use for the tests.
Parallel []int
// Sequential is the number of sequential requests to send per parallel process.
Sequential []int
// Streams is the nummber of streaming messages to send over the stream per routine.
Streams []int
// PubSub is the number of times to publish messages to the broker per routine.
PubSub []int
mu sync.Mutex
msgCount int
}
// Run will start the benchmark tests.
func (stc *ServiceTestConfig) Run(b *testing.B) {
if err := stc.validate(); err != nil {
b.Fatal("Failed to validate config", err)
}
// Run routines with sequential requests
stc.prepBench(b, "req", stc.runParSeqTest, stc.Sequential)
// Run routines with streams
stc.prepBench(b, "streams", stc.runParStreamTest, stc.Streams)
// Run routines with pub/sub
stc.prepBench(b, "pubsub", stc.runBrokerTest, stc.PubSub)
}
// prepBench will prepare the benmark by setting the right parameters,
// and invoking the test.
func (stc *ServiceTestConfig) prepBench(b *testing.B, tName string, test parTest, seq []int) {
par := stc.Parallel
// No requests needed
if len(seq) == 0 || seq[0] == 0 {
return
}
for _, parallel := range par {
for _, sequential := range seq {
// Create the service name for the test
name := fmt.Sprintf("%s.%dp-%d%s", stc.Name, parallel, sequential, tName)
// Run test with parallel routines making each sequential requests
test := func(name string, c client.Client, errChan chan error) {
test(name, c, parallel, sequential, errChan)
}
benchmark := func(b *testing.B) {
b.ReportAllocs()
stc.runBench(b, name, test)
}
b.Logf("----------- STARTING TEST %s -----------", name)
// Run test, return if it fails
if !b.Run(name, benchmark) {
return
}
}
}
}
// runParSeqTest will make s sequential requests in p parallel routines.
func (stc *ServiceTestConfig) runParSeqTest(name string, c client.Client, p, s int, errChan chan error) {
testParallel(p, func() {
// Make serial requests
for z := 0; z < s; z++ {
if err := testRequest(context.Background(), c, name); err != nil {
errChan <- errors.Wrapf(err, "[%s] Request failed during testRequest", name)
return
}
}
})
}
// Handle is used as a test handler.
func (stc *ServiceTestConfig) Handle(ctx context.Context, msg *proto.Request) error {
stc.mu.Lock()
stc.msgCount++
stc.mu.Unlock()
return nil
}
// HandleError is used as a test handler.
func (stc *ServiceTestConfig) HandleError(ctx context.Context, msg *proto.Request) error {
return errors.New("dummy error")
}
// runBrokerTest will publish messages to the broker to test pub/sub.
func (stc *ServiceTestConfig) runBrokerTest(name string, c client.Client, p, s int, errChan chan error) {
stc.msgCount = 0
testParallel(p, func() {
for z := 0; z < s; z++ {
msg := pb.BusMsg{Msg: "Hello from broker!"}
if err := c.Publish(context.Background(), c.NewMessage(testTopic, &msg)); err != nil {
errChan <- errors.Wrap(err, "failed to publish message to broker")
return
}
msg = pb.BusMsg{Msg: "Some message that will error"}
if err := c.Publish(context.Background(), c.NewMessage(errorTopic, &msg)); err == nil {
errChan <- errors.New("Publish is supposed to return an error, but got no error")
return
}
}
})
if stc.msgCount != s*p {
errChan <- fmt.Errorf("pub/sub does not work properly, invalid message count. Expected %d messaged, but received %d", s*p, stc.msgCount)
return
}
}
// runParStreamTest will start streaming, and send s messages parallel in p routines.
func (stc *ServiceTestConfig) runParStreamTest(name string, c client.Client, p, s int, errChan chan error) {
testParallel(p, func() {
// Create a client service
srv := pb.NewDebugService(name, c)
// Establish a connection to server over which we start streaming
bus, err := srv.MessageBus(context.Background())
if err != nil {
errChan <- errors.Wrap(err, "failed to connect to message bus")
return
}
// Start streaming requests
for z := 0; z < s; z++ {
if err := bus.Send(&pb.BusMsg{Msg: "Hack the world!"}); err != nil {
errChan <- errors.Wrap(err, "failed to send to stream")
return
}
msg, err := bus.Recv()
if err != nil {
errChan <- errors.Wrap(err, "failed to receive message from stream")
return
}
expected := "Request received!"
if msg.Msg != expected {
errChan <- fmt.Errorf("stream returned unexpected mesage. Expected '%s', but got '%s'", expected, msg.Msg)
return
}
}
})
}
// validate will make sure the provided test parameters are a legal combination.
func (stc *ServiceTestConfig) validate() error {
lp, lseq, lstr := len(stc.Parallel), len(stc.Sequential), len(stc.Streams)
if lp == 0 || (lseq == 0 && lstr == 0) {
return ErrNoTests
}
return nil
}
// runBench will create a service with the provided stc.NewService function,
// and run a benchmark on the test function.
func (stc *ServiceTestConfig) runBench(b *testing.B, name string, test testFunc) {
b.StopTimer()
// Channel to signal service has started
started := make(chan struct{})
// Context with cancel to stop the service
ctx, cancel := context.WithCancel(context.Background())
opts := []micro.Option{
micro.Context(ctx),
micro.AfterStart(func() error {
started <- struct{}{}
return nil
}),
}
// Create a new service per test
service, err := stc.NewService(name, opts...)
if err != nil {
b.Fatalf("failed to create service: %v", err)
}
// Register handler
if err := pb.RegisterDebugHandler(service.Server(), handler.NewHandler(service.Client())); err != nil {
b.Fatalf("failed to register handler during initial service setup: %v", err)
}
o := service.Options()
if err := o.Broker.Connect(); err != nil {
b.Fatal(err)
}
// a := new(testService)
if err := o.Server.Subscribe(o.Server.NewSubscriber(testTopic, stc.Handle)); err != nil {
b.Fatalf("[%s] Failed to register subscriber: %v", name, err)
}
if err := o.Server.Subscribe(o.Server.NewSubscriber(errorTopic, stc.HandleError)); err != nil {
b.Fatalf("[%s] Failed to register subscriber: %v", name, err)
}
b.Logf("# == [ Service ] ==================")
b.Logf("# * Server: %s", o.Server.String())
b.Logf("# * Client: %s", o.Client.String())
b.Logf("# * Transport: %s", o.Transport.String())
b.Logf("# * Broker: %s", o.Broker.String())
b.Logf("# * Registry: %s", o.Registry.String())
b.Logf("# * Auth: %s", o.Auth.String())
b.Logf("# * Cache: %s", o.Cache.String())
b.Logf("# * Runtime: %s", o.Runtime.String())
b.Logf("# ================================")
RunBenchmark(b, name, service, test, cancel, started)
}
// RunBenchmark will run benchmarks on a provided service.
//
// A test function can be provided that will be fun b.N times.
func RunBenchmark(b *testing.B, name string, service micro.Service, test testFunc,
cancel context.CancelFunc, started chan struct{}) {
b.StopTimer()
// Receive errors from routines on this channel
errChan := make(chan error, 1)
// Receive singal after service has shutdown
done := make(chan struct{})
// Start the server
go func() {
b.Logf("[%s] Starting server for benchmark", name)
if err := service.Run(); err != nil {
errChan <- errors.Wrapf(err, "[%s] Error occurred during service.Run", name)
}
done <- struct{}{}
}()
sigTerm := make(chan struct{})
// Benchmark routine
go func() {
defer func() {
b.StopTimer()
// Shutdown service
b.Logf("[%s] Shutting down", name)
cancel()
// Wait for service to be fully stopped
<-done
sigTerm <- struct{}{}
}()
// Wait for service to start
<-started
// Give the registry more time to setup
time.Sleep(time.Second)
b.Logf("[%s] Server started", name)
// Make a test call to warm the cache
for i := 0; i < 10; i++ {
if err := testRequest(context.Background(), service.Client(), name); err != nil {
errChan <- errors.Wrapf(err, "[%s] Failure during cache warmup testRequest", name)
}
}
// Check registration
services, err := service.Options().Registry.GetService(name)
if err != nil || len(services) == 0 {
errChan <- fmt.Errorf("service registration must have failed (%d services found), unable to get service: %w", len(services), err)
return
}
// Start benchmark
b.Logf("[%s] Starting benchtest", name)
b.ResetTimer()
b.StartTimer()
// Number of iterations
for i := 0; i < b.N; i++ {
test(name, service.Client(), errChan)
}
}()
// Wait for completion or catch any errors
select {
case err := <-errChan:
b.Fatal(err)
case <-sigTerm:
b.Logf("[%s] Completed benchmark", name)
}
}
// testParallel will run the test function in p parallel routines.
func testParallel(p int, test func()) {
// Waitgroup to wait for requests to finish
wg := sync.WaitGroup{}
// For concurrency
for j := 0; j < p; j++ {
wg.Add(1)
go func() {
defer wg.Done()
test()
}()
}
// Wait for test completion
wg.Wait()
}
// testRequest sends one test request.
// It calls the Debug.Health endpoint, and validates if the response returned
// contains the expected message.
func testRequest(ctx context.Context, c client.Client, name string) error {
req := c.NewRequest(
name,
"Debug.Health",
new(pb.HealthRequest),
)
rsp := new(pb.HealthResponse)
if err := c.Call(ctx, req, rsp); err != nil {
return err
}
if rsp.Status != "ok" {
return errors.New("service response: " + rsp.Status)
}
return nil
}

View File

@ -0,0 +1,33 @@
// headers is a package for internal micro global constants
package headers
const (
// Message header is a header for internal message communication.
Message = "Micro-Topic"
// Request header is a message header for internal request communication.
Request = "Micro-Service"
// Error header contains an error message.
Error = "Micro-Error"
// Endpoint header.
Endpoint = "Micro-Endpoint"
// Method header.
Method = "Micro-Method"
// ID header.
ID = "Micro-ID"
// Prefix used to prefix headers.
Prefix = "Micro-"
// Namespace header.
Namespace = "Micro-Namespace"
// Protocol header.
Protocol = "Micro-Protocol"
// Target header.
Target = "Micro-Target"
// ContentType header.
ContentType = "Content-Type"
// SpanID header.
SpanID = "Micro-Span-ID"
// TraceIDKey header.
TraceIDKey = "Micro-Trace-ID"
// Stream header.
Stream = "Micro-Stream"
)

202
transport/http_client.go Normal file
View File

@ -0,0 +1,202 @@
package transport
import (
"bufio"
"bytes"
"io"
"net"
"net/http"
"net/url"
"sync"
"time"
"github.com/pkg/errors"
log "go-micro.dev/v4/logger"
"go-micro.dev/v4/util/buf"
)
type httpTransportClient struct {
ht *httpTransport
addr string
conn net.Conn
dialOpts DialOptions
once sync.Once
sync.RWMutex
// request must be stored for response processing
req chan *http.Request
reqList []*http.Request
buff *bufio.Reader
closed bool
// local/remote ip
local string
remote string
}
func (h *httpTransportClient) Local() string {
return h.local
}
func (h *httpTransportClient) Remote() string {
return h.remote
}
func (h *httpTransportClient) Send(m *Message) error {
logger := h.ht.Options().Logger
header := make(http.Header)
for k, v := range m.Header {
header.Set(k, v)
}
b := buf.New(bytes.NewBuffer(m.Body))
defer func() {
if err := b.Close(); err != nil {
logger.Logf(log.ErrorLevel, "failed to close buffer: %v", err)
}
}()
req := &http.Request{
Method: http.MethodPost,
URL: &url.URL{
Scheme: "http",
Host: h.addr,
},
Header: header,
Body: b,
ContentLength: int64(b.Len()),
Host: h.addr,
Close: h.dialOpts.ConnClose,
}
if !h.dialOpts.Stream {
h.Lock()
if h.closed {
h.Unlock()
return io.EOF
}
h.reqList = append(h.reqList, req)
select {
case h.req <- h.reqList[0]:
h.reqList = h.reqList[1:]
default:
}
h.Unlock()
}
// set timeout if its greater than 0
if h.ht.opts.Timeout > time.Duration(0) {
if err := h.conn.SetDeadline(time.Now().Add(h.ht.opts.Timeout)); err != nil {
return err
}
}
return req.Write(h.conn)
}
// Recv receives a message.
func (h *httpTransportClient) Recv(msg *Message) (err error) {
if msg == nil {
return errors.New("message passed in is nil")
}
var req *http.Request
if !h.dialOpts.Stream {
rc, ok := <-h.req
if !ok {
h.Lock()
if len(h.reqList) == 0 {
h.Unlock()
return io.EOF
}
rc = h.reqList[0]
h.reqList = h.reqList[1:]
h.Unlock()
}
req = rc
}
// set timeout if its greater than 0
if h.ht.opts.Timeout > time.Duration(0) {
if err = h.conn.SetDeadline(time.Now().Add(h.ht.opts.Timeout)); err != nil {
return err
}
}
h.Lock()
defer h.Unlock()
if h.closed {
return io.EOF
}
rsp, err := http.ReadResponse(h.buff, req)
if err != nil {
return err
}
defer func() {
if err = rsp.Body.Close(); err != nil {
err = errors.Wrap(err, "failed to close body")
}
}()
b, err := io.ReadAll(rsp.Body)
if err != nil {
return err
}
if rsp.StatusCode != http.StatusOK {
return errors.New(rsp.Status + ": " + string(b))
}
msg.Body = b
if msg.Header == nil {
msg.Header = make(map[string]string, len(rsp.Header))
}
for k, v := range rsp.Header {
if len(v) > 0 {
msg.Header[k] = v[0]
} else {
msg.Header[k] = ""
}
}
return nil
}
func (h *httpTransportClient) Close() error {
if !h.dialOpts.Stream {
h.once.Do(func() {
h.Lock()
h.buff.Reset(nil)
h.closed = true
h.Unlock()
close(h.req)
})
return h.conn.Close()
}
err := h.conn.Close()
h.once.Do(func() {
h.Lock()
h.buff.Reset(nil)
h.closed = true
h.Unlock()
close(h.req)
})
return err
}

132
transport/http_listener.go Normal file
View File

@ -0,0 +1,132 @@
package transport
import (
"bufio"
"bytes"
"io"
"net"
"net/http"
"time"
log "go-micro.dev/v4/logger"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
type httpTransportListener struct {
ht *httpTransport
listener net.Listener
}
func (h *httpTransportListener) Addr() string {
return h.listener.Addr().String()
}
func (h *httpTransportListener) Close() error {
return h.listener.Close()
}
func (h *httpTransportListener) Accept(fn func(Socket)) error {
// Create handler mux
// TODO: see if we should make a plugin out of the mux
mux := http.NewServeMux()
// Register our transport handler
mux.HandleFunc("/", h.newHandler(fn))
// Get optional handlers
// TODO: This needs to be documented clearer, and examples provided
if h.ht.opts.Context != nil {
handlers, ok := h.ht.opts.Context.Value("http_handlers").(map[string]http.Handler)
if ok {
for pattern, handler := range handlers {
mux.Handle(pattern, handler)
}
}
}
// Server ONLY supports HTTP1 + H2C
srv := &http.Server{
Handler: mux,
ReadHeaderTimeout: time.Second * 5,
}
// insecure connection use h2c
if !(h.ht.opts.Secure || h.ht.opts.TLSConfig != nil) {
srv.Handler = h2c.NewHandler(mux, &http2.Server{})
}
return srv.Serve(h.listener)
}
// newHandler creates a new HTTP transport handler passed to the mux.
func (h *httpTransportListener) newHandler(serveConn func(Socket)) func(rsp http.ResponseWriter, req *http.Request) {
logger := h.ht.opts.Logger
return func(rsp http.ResponseWriter, req *http.Request) {
var (
buf *bufio.ReadWriter
con net.Conn
)
// HTTP1: read a regular request
if req.ProtoMajor == 1 {
b, err := io.ReadAll(req.Body)
if err != nil {
http.Error(rsp, err.Error(), http.StatusInternalServerError)
return
}
req.Body = io.NopCloser(bytes.NewReader(b))
// Hijack the conn
// We also don't close the connection here, as it will be closed by
// the httpTransportSocket
hj, ok := rsp.(http.Hijacker)
if !ok {
// We're screwed
http.Error(rsp, "cannot serve conn", http.StatusInternalServerError)
return
}
conn, bufrw, err := hj.Hijack()
if err != nil {
http.Error(rsp, err.Error(), http.StatusInternalServerError)
return
}
defer func() {
if err := conn.Close(); err != nil {
logger.Logf(log.ErrorLevel, "Failed to close TCP connection: %v", err)
}
}()
buf = bufrw
con = conn
}
// Buffered reader
bufr := bufio.NewReader(req.Body)
// Save the request
ch := make(chan *http.Request, 1)
ch <- req
// Create a new transport socket
sock := &httpTransportSocket{
ht: h.ht,
w: rsp,
r: req,
rw: buf,
buf: bufr,
ch: ch,
conn: con,
local: h.Addr(),
remote: req.RemoteAddr,
closed: make(chan bool),
}
// Execute the socket
serveConn(sock)
}
}

263
transport/http_socket.go Normal file
View File

@ -0,0 +1,263 @@
package transport
import (
"bufio"
"bytes"
"io"
"net"
"net/http"
"sync"
"time"
"github.com/pkg/errors"
)
type httpTransportSocket struct {
ht *httpTransport
w http.ResponseWriter
r *http.Request
rw *bufio.ReadWriter
mtx sync.RWMutex
// the hijacked when using http 1
conn net.Conn
// for the first request
ch chan *http.Request
// h2 things
buf *bufio.Reader
// indicate if socket is closed
closed chan bool
// local/remote ip
local string
remote string
}
func (h *httpTransportSocket) Local() string {
return h.local
}
func (h *httpTransportSocket) Remote() string {
return h.remote
}
func (h *httpTransportSocket) Recv(msg *Message) error {
if msg == nil {
return errors.New("message passed in is nil")
}
if msg.Header == nil {
msg.Header = make(map[string]string, len(h.r.Header))
}
if h.r.ProtoMajor == 1 {
return h.recvHTTP1(msg)
}
return h.recvHTTP2(msg)
}
func (h *httpTransportSocket) Send(msg *Message) error {
// we need to lock to protect the write
h.mtx.RLock()
defer h.mtx.RUnlock()
if h.r.ProtoMajor == 1 {
return h.sendHTTP1(msg)
}
return h.sendHTTP2(msg)
}
func (h *httpTransportSocket) Close() error {
h.mtx.Lock()
defer h.mtx.Unlock()
select {
case <-h.closed:
return nil
default:
// Close the channel
close(h.closed)
// Close the buffer
if err := h.r.Body.Close(); err != nil {
return err
}
}
return nil
}
func (h *httpTransportSocket) error(m *Message) error {
if h.r.ProtoMajor == 1 {
rsp := &http.Response{
Header: make(http.Header),
Body: io.NopCloser(bytes.NewReader(m.Body)),
Status: "500 Internal Server Error",
StatusCode: http.StatusInternalServerError,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
ContentLength: int64(len(m.Body)),
}
for k, v := range m.Header {
rsp.Header.Set(k, v)
}
return rsp.Write(h.conn)
}
return nil
}
func (h *httpTransportSocket) recvHTTP1(msg *Message) error {
// set timeout if its greater than 0
if h.ht.opts.Timeout > time.Duration(0) {
if err := h.conn.SetDeadline(time.Now().Add(h.ht.opts.Timeout)); err != nil {
return errors.Wrap(err, "failed to set deadline")
}
}
var req *http.Request
select {
// get first request
case req = <-h.ch:
// read next request
default:
rr, err := http.ReadRequest(h.rw.Reader)
if err != nil {
return errors.Wrap(err, "failed to read request")
}
req = rr
}
// read body
b, err := io.ReadAll(req.Body)
if err != nil {
return errors.Wrap(err, "failed to read body")
}
// set body
if err := req.Body.Close(); err != nil {
return errors.Wrap(err, "failed to close body")
}
msg.Body = b
// set headers
for k, v := range req.Header {
if len(v) > 0 {
msg.Header[k] = v[0]
} else {
msg.Header[k] = ""
}
}
// return early early
return nil
}
func (h *httpTransportSocket) recvHTTP2(msg *Message) error {
// only process if the socket is open
select {
case <-h.closed:
return io.EOF
default:
}
// read streaming body
// set max buffer size
s := h.ht.opts.BuffSizeH2
if s == 0 {
s = DefaultBufSizeH2
}
buf := make([]byte, s)
// read the request body
n, err := h.buf.Read(buf)
// not an eof error
if err != nil {
return err
}
// check if we have data
if n > 0 {
msg.Body = buf[:n]
}
// set headers
for k, v := range h.r.Header {
if len(v) > 0 {
msg.Header[k] = v[0]
} else {
msg.Header[k] = ""
}
}
// set path
msg.Header[":path"] = h.r.URL.Path
return nil
}
func (h *httpTransportSocket) sendHTTP1(msg *Message) error {
// make copy of header
hdr := make(http.Header)
for k, v := range h.r.Header {
hdr[k] = v
}
rsp := &http.Response{
Header: hdr,
Body: io.NopCloser(bytes.NewReader(msg.Body)),
Status: "200 OK",
StatusCode: http.StatusOK,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
ContentLength: int64(len(msg.Body)),
}
for k, v := range msg.Header {
rsp.Header.Set(k, v)
}
// set timeout if its greater than 0
if h.ht.opts.Timeout > time.Duration(0) {
if err := h.conn.SetDeadline(time.Now().Add(h.ht.opts.Timeout)); err != nil {
return err
}
}
return rsp.Write(h.conn)
}
func (h *httpTransportSocket) sendHTTP2(msg *Message) error {
// only process if the socket is open
select {
case <-h.closed:
return io.EOF
default:
}
// set headers
for k, v := range msg.Header {
h.w.Header().Set(k, v)
}
// write request
_, err := h.w.Write(msg.Body)
// flush the trailers
h.w.(http.Flusher).Flush()
return err
}

View File

@ -2,529 +2,41 @@ package transport
import ( import (
"bufio" "bufio"
"bytes"
"crypto/tls" "crypto/tls"
"errors"
"io"
"net" "net"
"net/http" "net/http"
"net/url"
"sync"
"time"
"go-micro.dev/v4/logger"
maddr "go-micro.dev/v4/util/addr" maddr "go-micro.dev/v4/util/addr"
"go-micro.dev/v4/util/buf"
mnet "go-micro.dev/v4/util/net" mnet "go-micro.dev/v4/util/net"
mls "go-micro.dev/v4/util/tls" mls "go-micro.dev/v4/util/tls"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
) )
type httpTransport struct { type httpTransport struct {
opts Options opts Options
} }
type httpTransportClient struct { func NewHTTPTransport(opts ...Option) *httpTransport {
ht *httpTransport options := Options{
addr string BuffSizeH2: DefaultBufSizeH2,
conn net.Conn Logger: logger.DefaultLogger,
dialOpts DialOptions }
once sync.Once
sync.RWMutex for _, o := range opts {
o(&options)
}
// request must be stored for response processing return &httpTransport{opts: options}
r chan *http.Request
bl []*http.Request
buff *bufio.Reader
closed bool
// local/remote ip
local string
remote string
} }
type httpTransportSocket struct { func (h *httpTransport) Init(opts ...Option) error {
ht *httpTransport for _, o := range opts {
w http.ResponseWriter o(&h.opts)
r *http.Request
rw *bufio.ReadWriter
mtx sync.RWMutex
// the hijacked when using http 1
conn net.Conn
// for the first request
ch chan *http.Request
// h2 things
buf *bufio.Reader
// indicate if socket is closed
closed chan bool
// local/remote ip
local string
remote string
}
type httpTransportListener struct {
ht *httpTransport
listener net.Listener
}
func (h *httpTransportClient) Local() string {
return h.local
}
func (h *httpTransportClient) Remote() string {
return h.remote
}
func (h *httpTransportClient) Send(m *Message) error {
header := make(http.Header)
for k, v := range m.Header {
header.Set(k, v)
}
b := buf.New(bytes.NewBuffer(m.Body))
defer b.Close()
req := &http.Request{
Method: http.MethodPost,
URL: &url.URL{
Scheme: "http",
Host: h.addr,
},
Header: header,
Body: b,
ContentLength: int64(b.Len()),
Host: h.addr,
}
if !h.dialOpts.Stream {
h.Lock()
if h.closed {
h.Unlock()
return io.EOF
}
h.bl = append(h.bl, req)
select {
case h.r <- h.bl[0]:
h.bl = h.bl[1:]
default:
}
h.Unlock()
}
// set timeout if its greater than 0
if h.ht.opts.Timeout > time.Duration(0) {
if err := h.conn.SetDeadline(time.Now().Add(h.ht.opts.Timeout)); err != nil {
return err
}
}
return req.Write(h.conn)
}
// Recv receives a message.
func (h *httpTransportClient) Recv(msg *Message) error {
if msg == nil {
return errors.New("message passed in is nil")
}
var req *http.Request
if !h.dialOpts.Stream {
rc, ok := <-h.r
if !ok {
h.Lock()
if len(h.bl) == 0 {
h.Unlock()
return io.EOF
}
rc = h.bl[0]
h.bl = h.bl[1:]
h.Unlock()
}
req = rc
}
// set timeout if its greater than 0
if h.ht.opts.Timeout > time.Duration(0) {
if err := h.conn.SetDeadline(time.Now().Add(h.ht.opts.Timeout)); err != nil {
return err
}
}
h.Lock()
defer h.Unlock()
if h.closed {
return io.EOF
}
rsp, err := http.ReadResponse(h.buff, req)
if err != nil {
return err
}
defer rsp.Body.Close()
b, err := io.ReadAll(rsp.Body)
if err != nil {
return err
}
if rsp.StatusCode != http.StatusOK {
return errors.New(rsp.Status + ": " + string(b))
}
msg.Body = b
if msg.Header == nil {
msg.Header = make(map[string]string, len(rsp.Header))
}
for k, v := range rsp.Header {
if len(v) > 0 {
msg.Header[k] = v[0]
} else {
msg.Header[k] = ""
}
} }
return nil return nil
} }
func (h *httpTransportClient) Close() error {
if !h.dialOpts.Stream {
h.once.Do(func() {
h.Lock()
h.buff.Reset(nil)
h.closed = true
h.Unlock()
close(h.r)
})
return h.conn.Close()
}
err := h.conn.Close()
h.once.Do(func() {
h.Lock()
h.buff.Reset(nil)
h.closed = true
h.Unlock()
close(h.r)
})
return err
}
func (h *httpTransportSocket) Local() string {
return h.local
}
func (h *httpTransportSocket) Remote() string {
return h.remote
}
func (h *httpTransportSocket) Recv(msg *Message) error {
if msg == nil {
return errors.New("message passed in is nil")
}
if msg.Header == nil {
msg.Header = make(map[string]string, len(h.r.Header))
}
// process http 1
if h.r.ProtoMajor == 1 {
// set timeout if its greater than 0
if h.ht.opts.Timeout > time.Duration(0) {
h.conn.SetDeadline(time.Now().Add(h.ht.opts.Timeout))
}
var r *http.Request
select {
// get first request
case r = <-h.ch:
// read next request
default:
rr, err := http.ReadRequest(h.rw.Reader)
if err != nil {
return err
}
r = rr
}
// read body
b, err := io.ReadAll(r.Body)
if err != nil {
return err
}
// set body
r.Body.Close()
msg.Body = b
// set headers
for k, v := range r.Header {
if len(v) > 0 {
msg.Header[k] = v[0]
} else {
msg.Header[k] = ""
}
}
// return early early
return nil
}
// only process if the socket is open
select {
case <-h.closed:
return io.EOF
default:
// no op
}
// processing http2 request
// read streaming body
// set max buffer size
// TODO: adjustable buffer size
buf := make([]byte, 4*1024*1024)
// read the request body
n, err := h.buf.Read(buf)
// not an eof error
if err != nil {
return err
}
// check if we have data
if n > 0 {
msg.Body = buf[:n]
}
// set headers
for k, v := range h.r.Header {
if len(v) > 0 {
msg.Header[k] = v[0]
} else {
msg.Header[k] = ""
}
}
// set path
msg.Header[":path"] = h.r.URL.Path
return nil
}
func (h *httpTransportSocket) Send(msg *Message) error {
if h.r.ProtoMajor == 1 {
// make copy of header
hdr := make(http.Header)
for k, v := range h.r.Header {
hdr[k] = v
}
rsp := &http.Response{
Header: hdr,
Body: io.NopCloser(bytes.NewReader(msg.Body)),
Status: "200 OK",
StatusCode: http.StatusOK,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
ContentLength: int64(len(msg.Body)),
}
for k, v := range msg.Header {
rsp.Header.Set(k, v)
}
// set timeout if its greater than 0
if h.ht.opts.Timeout > time.Duration(0) {
h.conn.SetDeadline(time.Now().Add(h.ht.opts.Timeout))
}
return rsp.Write(h.conn)
}
// only process if the socket is open
select {
case <-h.closed:
return io.EOF
default:
// no op
}
// we need to lock to protect the write
h.mtx.RLock()
defer h.mtx.RUnlock()
// set headers
for k, v := range msg.Header {
h.w.Header().Set(k, v)
}
// write request
_, err := h.w.Write(msg.Body)
// flush the trailers
h.w.(http.Flusher).Flush()
return err
}
func (h *httpTransportSocket) error(m *Message) error {
if h.r.ProtoMajor == 1 {
rsp := &http.Response{
Header: make(http.Header),
Body: io.NopCloser(bytes.NewReader(m.Body)),
Status: "500 Internal Server Error",
StatusCode: http.StatusInternalServerError,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
ContentLength: int64(len(m.Body)),
}
for k, v := range m.Header {
rsp.Header.Set(k, v)
}
return rsp.Write(h.conn)
}
return nil
}
func (h *httpTransportSocket) Close() error {
h.mtx.Lock()
defer h.mtx.Unlock()
select {
case <-h.closed:
return nil
default:
// close the channel
close(h.closed)
// close the buffer
h.r.Body.Close()
// close the connection
if h.r.ProtoMajor == 1 {
return h.conn.Close()
}
}
return nil
}
func (h *httpTransportListener) Addr() string {
return h.listener.Addr().String()
}
func (h *httpTransportListener) Close() error {
return h.listener.Close()
}
func (h *httpTransportListener) Accept(fn func(Socket)) error {
// create handler mux
mux := http.NewServeMux()
// register our transport handler
mux.HandleFunc("/", func(rsp http.ResponseWriter, req *http.Request) {
var buf *bufio.ReadWriter
var con net.Conn
// read a regular request
if req.ProtoMajor == 1 {
b, err := io.ReadAll(req.Body)
if err != nil {
http.Error(rsp, err.Error(), http.StatusInternalServerError)
return
}
req.Body = io.NopCloser(bytes.NewReader(b))
// hijack the conn
hj, ok := rsp.(http.Hijacker)
if !ok {
// we're screwed
http.Error(rsp, "cannot serve conn", http.StatusInternalServerError)
return
}
conn, bufrw, err := hj.Hijack()
if err != nil {
http.Error(rsp, err.Error(), http.StatusInternalServerError)
return
}
defer conn.Close()
buf = bufrw
con = conn
}
// buffered reader
bufr := bufio.NewReader(req.Body)
// save the request
ch := make(chan *http.Request, 1)
ch <- req
// create a new transport socket
sock := &httpTransportSocket{
ht: h.ht,
w: rsp,
r: req,
rw: buf,
buf: bufr,
ch: ch,
conn: con,
local: h.Addr(),
remote: req.RemoteAddr,
closed: make(chan bool),
}
// execute the socket
fn(sock)
})
// get optional handlers
if h.ht.opts.Context != nil {
handlers, ok := h.ht.opts.Context.Value("http_handlers").(map[string]http.Handler)
if ok {
for pattern, handler := range handlers {
mux.Handle(pattern, handler)
}
}
}
// default http2 server
srv := &http.Server{
Handler: mux,
}
// insecure connection use h2c
if !(h.ht.opts.Secure || h.ht.opts.TLSConfig != nil) {
srv.Handler = h2c.NewHandler(mux, &http2.Server{})
}
// begin serving
return srv.Serve(h.listener)
}
func (h *httpTransport) Dial(addr string, opts ...DialOption) (Client, error) { func (h *httpTransport) Dial(addr string, opts ...DialOption) (Client, error) {
dopts := DialOptions{ dopts := DialOptions{
Timeout: DefaultDialTimeout, Timeout: DefaultDialTimeout,
@ -539,12 +51,11 @@ func (h *httpTransport) Dial(addr string, opts ...DialOption) (Client, error) {
err error err error
) )
// TODO: support dial option here rather than using internal config
if h.opts.Secure || h.opts.TLSConfig != nil { if h.opts.Secure || h.opts.TLSConfig != nil {
config := h.opts.TLSConfig config := h.opts.TLSConfig
if config == nil { if config == nil {
config = &tls.Config{ config = &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: dopts.InsecureSkipVerify,
} }
} }
@ -569,7 +80,7 @@ func (h *httpTransport) Dial(addr string, opts ...DialOption) (Client, error) {
conn: conn, conn: conn,
buff: bufio.NewReader(conn), buff: bufio.NewReader(conn),
dialOpts: dopts, dialOpts: dopts,
r: make(chan *http.Request, 100), req: make(chan *http.Request, 100),
local: conn.LocalAddr().String(), local: conn.LocalAddr().String(),
remote: conn.RemoteAddr().String(), remote: conn.RemoteAddr().String(),
}, nil }, nil
@ -586,45 +97,58 @@ func (h *httpTransport) Listen(addr string, opts ...ListenOption) (Listener, err
err error err error
) )
if listener := getNetListener(&options); listener != nil { switch listener := getNetListener(&options); {
fn := func(addr string) (net.Listener, error) { // Extracted listener from context
case listener != nil:
getList := func(addr string) (net.Listener, error) {
return listener, nil return listener, nil
} }
list, err = mnet.Listen(addr, fn) list, err = mnet.Listen(addr, getList)
} else if h.opts.Secure || h.opts.TLSConfig != nil {
// Needs to create self signed certificate
case h.opts.Secure || h.opts.TLSConfig != nil:
config := h.opts.TLSConfig config := h.opts.TLSConfig
fn := func(addr string) (net.Listener, error) { getList := func(addr string) (net.Listener, error) {
if config == nil { if config != nil {
hosts := []string{addr} return tls.Listen("tcp", addr, config)
// check if its a valid host:port
if host, _, err := net.SplitHostPort(addr); err == nil {
if len(host) == 0 {
hosts = maddr.IPs()
} else {
hosts = []string{host}
}
}
// generate a certificate
cert, err := mls.Certificate(hosts...)
if err != nil {
return nil, err
}
config = &tls.Config{Certificates: []tls.Certificate{cert}}
} }
hosts := []string{addr}
// check if its a valid host:port
if host, _, err := net.SplitHostPort(addr); err == nil {
if len(host) == 0 {
hosts = maddr.IPs()
} else {
hosts = []string{host}
}
}
// generate a certificate
cert, err := mls.Certificate(hosts...)
if err != nil {
return nil, err
}
config = &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
}
return tls.Listen("tcp", addr, config) return tls.Listen("tcp", addr, config)
} }
list, err = mnet.Listen(addr, fn) list, err = mnet.Listen(addr, getList)
} else {
fn := func(addr string) (net.Listener, error) { // Create new basic net listener
default:
getList := func(addr string) (net.Listener, error) {
return net.Listen("tcp", addr) return net.Listen("tcp", addr)
} }
list, err = mnet.Listen(addr, fn) list, err = mnet.Listen(addr, getList)
} }
if err != nil { if err != nil {
@ -637,14 +161,6 @@ func (h *httpTransport) Listen(addr string, opts ...ListenOption) (Listener, err
}, nil }, nil
} }
func (h *httpTransport) Init(opts ...Option) error {
for _, o := range opts {
o(&h.opts)
}
return nil
}
func (h *httpTransport) Options() Options { func (h *httpTransport) Options() Options {
return h.opts return h.opts
} }
@ -652,12 +168,3 @@ func (h *httpTransport) Options() Options {
func (h *httpTransport) String() string { func (h *httpTransport) String() string {
return "http" return "http"
} }
func NewHTTPTransport(opts ...Option) *httpTransport {
var options Options
for _, o := range opts {
o(&options)
}
return &httpTransport{opts: options}
}

View File

@ -10,6 +10,10 @@ import (
"go-micro.dev/v4/logger" "go-micro.dev/v4/logger"
) )
var (
DefaultBufSizeH2 = 4 * 1024 * 1024
)
type Options struct { type Options struct {
// Addrs is the list of intermediary addresses to connect to // Addrs is the list of intermediary addresses to connect to
Addrs []string Addrs []string
@ -30,6 +34,8 @@ type Options struct {
Context context.Context Context context.Context
// Logger is the underline logger // Logger is the underline logger
Logger logger.Logger Logger logger.Logger
// BuffSizeH2 is the HTTP2 buffer size
BuffSizeH2 int
} }
type DialOptions struct { type DialOptions struct {
@ -38,6 +44,10 @@ type DialOptions struct {
Stream bool Stream bool
// Timeout for dialing // Timeout for dialing
Timeout time.Duration Timeout time.Duration
// ConnClose sets the Connection header to close
ConnClose bool
// InsecureSkipVerify skip TLS verification.
InsecureSkipVerify bool
// TODO: add tls options when dialing // TODO: add tls options when dialing
// Currently set in global options // Currently set in global options
@ -106,22 +116,46 @@ func WithTimeout(d time.Duration) DialOption {
} }
} }
// WithLogger sets the underline logger. // WithConnClose sets the Connection header to close.
func WithLogger(l logger.Logger) Option { func WithConnClose() DialOption {
return func(o *DialOptions) {
o.ConnClose = true
}
}
func WithInsecureSkipVerify(b bool) DialOption {
return func(o *DialOptions) {
o.InsecureSkipVerify = b
}
}
// Logger sets the underline logger.
func Logger(l logger.Logger) Option {
return func(o *Options) { return func(o *Options) {
o.Logger = l o.Logger = l
} }
} }
// BuffSizeH2 sets the HTTP2 buffer size.
// Default is 4 * 1024 * 1024.
func BuffSizeH2(size int) Option {
return func(o *Options) {
o.BuffSizeH2 = size
}
}
// InsecureSkipVerify sets the TLS options to skip verification.
// NetListener Set net.Listener for httpTransport. // NetListener Set net.Listener for httpTransport.
func NetListener(customListener net.Listener) ListenOption { func NetListener(customListener net.Listener) ListenOption {
return func(o *ListenOptions) { return func(o *ListenOptions) {
if customListener == nil { if customListener == nil {
return return
} }
if o.Context == nil { if o.Context == nil {
o.Context = context.TODO() o.Context = context.TODO()
} }
o.Context = context.WithValue(o.Context, netListener{}, customListener) o.Context = context.WithValue(o.Context, netListener{}, customListener)
} }
} }

View File

@ -1,55 +1,30 @@
// addr provides functions to retrieve local IP addresses from device interfaces.
package addr package addr
import ( import (
"fmt"
"net" "net"
"github.com/pkg/errors"
) )
var ( var (
privateBlocks []*net.IPNet // ErrIPNotFound no IP address found, and explicit IP not provided.
ErrIPNotFound = errors.New("no IP address found, and explicit IP not provided")
) )
func init() { // IsLocal checks whether an IP belongs to one of the device's interfaces.
for _, b := range []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "100.64.0.0/10", "fd00::/8"} {
if _, block, err := net.ParseCIDR(b); err == nil {
privateBlocks = append(privateBlocks, block)
}
}
}
// AppendPrivateBlocks append private network blocks.
func AppendPrivateBlocks(bs ...string) {
for _, b := range bs {
if _, block, err := net.ParseCIDR(b); err == nil {
privateBlocks = append(privateBlocks, block)
}
}
}
func isPrivateIP(ipAddr string) bool {
ip := net.ParseIP(ipAddr)
for _, priv := range privateBlocks {
if priv.Contains(ip) {
return true
}
}
return false
}
// IsLocal tells us whether an ip is local.
func IsLocal(addr string) bool { func IsLocal(addr string) bool {
// extract the host // Extract the host
host, _, err := net.SplitHostPort(addr) host, _, err := net.SplitHostPort(addr)
if err == nil { if err == nil {
addr = host addr = host
} }
// check if its localhost
if addr == "localhost" { if addr == "localhost" {
return true return true
} }
// check against all local ips // Check against all local ips
for _, ip := range IPs() { for _, ip := range IPs() {
if addr == ip { if addr == ip {
return true return true
@ -59,79 +34,53 @@ func IsLocal(addr string) bool {
return false return false
} }
// Extract returns a real ip. // Extract returns a valid IP address. If the address provided is a valid
// address, it will be returned directly. Otherwise the available interfaces
// be itterated over to find an IP address, prefferably private.
func Extract(addr string) (string, error) { func Extract(addr string) (string, error) {
// if addr specified then its returned // if addr is already specified then it's directly returned
if len(addr) > 0 && (addr != "0.0.0.0" && addr != "[::]" && addr != "::") { if len(addr) > 0 && (addr != "0.0.0.0" && addr != "[::]" && addr != "::") {
return addr, nil return addr, nil
} }
var (
addrs []net.Addr
loAddrs []net.Addr
)
ifaces, err := net.Interfaces() ifaces, err := net.Interfaces()
if err != nil { if err != nil {
return "", fmt.Errorf("Failed to get interfaces! Err: %v", err) return "", errors.Wrap(err, "failed to get interfaces")
} }
var addrs []net.Addr
var loAddrs []net.Addr
for _, iface := range ifaces { for _, iface := range ifaces {
ifaceAddrs, err := iface.Addrs() ifaceAddrs, err := iface.Addrs()
if err != nil { if err != nil {
// ignore error, interface can disappear from system // ignore error, interface can disappear from system
continue continue
} }
if iface.Flags&net.FlagLoopback != 0 { if iface.Flags&net.FlagLoopback != 0 {
loAddrs = append(loAddrs, ifaceAddrs...) loAddrs = append(loAddrs, ifaceAddrs...)
continue continue
} }
addrs = append(addrs, ifaceAddrs...) addrs = append(addrs, ifaceAddrs...)
} }
// Add loopback addresses to the end of the list
addrs = append(addrs, loAddrs...) addrs = append(addrs, loAddrs...)
var ipAddr string // Try to find private IP in list, public IP otherwise
var publicIP string ip, err := findIP(addrs)
if err != nil {
for _, rawAddr := range addrs { return "", err
var ip net.IP
switch addr := rawAddr.(type) {
case *net.IPAddr:
ip = addr.IP
case *net.IPNet:
ip = addr.IP
default:
continue
}
if !isPrivateIP(ip.String()) {
publicIP = ip.String()
continue
}
ipAddr = ip.String()
break
} }
// return private ip return ip.String(), nil
if len(ipAddr) > 0 {
a := net.ParseIP(ipAddr)
if a == nil {
return "", fmt.Errorf("ip addr %s is invalid", ipAddr)
}
return a.String(), nil
}
// return public or virtual ip
if len(publicIP) > 0 {
a := net.ParseIP(publicIP)
if a == nil {
return "", fmt.Errorf("ip addr %s is invalid", publicIP)
}
return a.String(), nil
}
return "", fmt.Errorf("No IP address found, and explicit IP not provided")
} }
// IPs returns all known ips. // IPs returns all available interface IP addresses.
func IPs() []string { func IPs() []string {
ifaces, err := net.Interfaces() ifaces, err := net.Interfaces()
if err != nil { if err != nil {
@ -159,17 +108,42 @@ func IPs() []string {
continue continue
} }
// dont skip ipv6 addrs
/*
ip = ip.To4()
if ip == nil {
continue
}
*/
ipAddrs = append(ipAddrs, ip.String()) ipAddrs = append(ipAddrs, ip.String())
} }
} }
return ipAddrs return ipAddrs
} }
// findIP will return the first private IP available in the list,
// if no private IP is available it will return a public IP if present.
func findIP(addresses []net.Addr) (net.IP, error) {
var publicIP net.IP
for _, rawAddr := range addresses {
var ip net.IP
switch addr := rawAddr.(type) {
case *net.IPAddr:
ip = addr.IP
case *net.IPNet:
ip = addr.IP
default:
continue
}
if !ip.IsPrivate() {
publicIP = ip
continue
}
// Return private IP if available
return ip, nil
}
// Return public or virtual IP
if len(publicIP) > 0 {
return publicIP, nil
}
return nil, ErrIPNotFound
}

View File

@ -54,24 +54,3 @@ func TestExtractor(t *testing.T) {
} }
} }
} }
func TestAppendPrivateBlocks(t *testing.T) {
tests := []struct {
addr string
expect bool
}{
{addr: "9.134.71.34", expect: true},
{addr: "8.10.110.34", expect: false}, // not in private blocks
}
AppendPrivateBlocks("9.134.0.0/16")
for _, test := range tests {
t.Run(test.addr, func(t *testing.T) {
res := isPrivateIP(test.addr)
if res != test.expect {
t.Fatalf("expected %t got %t", test.expect, res)
}
})
}
}

View File

@ -5,6 +5,7 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"go-micro.dev/v4/transport" "go-micro.dev/v4/transport"
) )
@ -34,14 +35,21 @@ func newPool(options Options) *pool {
func (p *pool) Close() error { func (p *pool) Close() error {
p.Lock() p.Lock()
defer p.Unlock()
var err error
for k, c := range p.conns { for k, c := range p.conns {
for _, conn := range c { for _, conn := range c {
conn.Client.Close() if nerr := conn.Client.Close(); nerr != nil {
err = nerr
}
} }
delete(p.conns, k) delete(p.conns, k)
} }
p.Unlock()
return nil return err
} }
// NoOp the Close since we manage it. // NoOp the Close since we manage it.
@ -61,20 +69,24 @@ func (p *pool) Get(addr string, opts ...transport.DialOption) (Conn, error) {
p.Lock() p.Lock()
conns := p.conns[addr] conns := p.conns[addr]
// while we have conns check age and then return one // While we have conns check age and then return one
// otherwise we'll create a new conn // otherwise we'll create a new conn
for len(conns) > 0 { for len(conns) > 0 {
conn := conns[len(conns)-1] conn := conns[len(conns)-1]
conns = conns[:len(conns)-1] conns = conns[:len(conns)-1]
p.conns[addr] = conns p.conns[addr] = conns
// if conn is old kill it and move on // If conn is old kill it and move on
if d := time.Since(conn.Created()); d > p.ttl { if d := time.Since(conn.Created()); d > p.ttl {
conn.Client.Close() if err := conn.Client.Close(); err != nil {
p.Unlock()
return nil, err
}
continue continue
} }
// we got a good conn, lets unlock and return it // We got a good conn, lets unlock and return it
p.Unlock() p.Unlock()
return conn, nil return conn, nil
@ -87,6 +99,7 @@ func (p *pool) Get(addr string, opts ...transport.DialOption) (Conn, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &poolConn{ return &poolConn{
Client: c, Client: c,
id: uuid.New().String(), id: uuid.New().String(),
@ -102,13 +115,14 @@ func (p *pool) Release(conn Conn, err error) error {
// otherwise put it back for reuse // otherwise put it back for reuse
p.Lock() p.Lock()
defer p.Unlock()
conns := p.conns[conn.Remote()] conns := p.conns[conn.Remote()]
if len(conns) >= p.size { if len(conns) >= p.size {
p.Unlock()
return conn.(*poolConn).Client.Close() return conn.(*poolConn).Client.Close()
} }
p.conns[conn.Remote()] = append(conns, conn.(*poolConn)) p.conns[conn.Remote()] = append(conns, conn.(*poolConn))
p.Unlock()
return nil return nil
} }

View File

@ -13,10 +13,11 @@ type Pool interface {
Close() error Close() error
// Get a connection // Get a connection
Get(addr string, opts ...transport.DialOption) (Conn, error) Get(addr string, opts ...transport.DialOption) (Conn, error)
// Releaes the connection // Release the connection
Release(c Conn, status error) error Release(c Conn, status error) error
} }
// Conn interface represents a pool connection.
type Conn interface { type Conn interface {
// unique id of connection // unique id of connection
Id() string Id() string
@ -26,10 +27,12 @@ type Conn interface {
transport.Client transport.Client
} }
// NewPool will return a new pool object.
func NewPool(opts ...Option) Pool { func NewPool(opts ...Option) Pool {
var options Options var options Options
for _, o := range opts { for _, o := range opts {
o(&options) o(&options)
} }
return newPool(options) return newPool(options)
} }

View File

@ -5,7 +5,7 @@ import (
) )
var ( var (
// mock registry data. // Data is a set of mock registry data.
Data = map[string][]*registry.Service{ Data = map[string][]*registry.Service{
"foo": { "foo": {
{ {
@ -45,3 +45,14 @@ var (
}, },
} }
) )
// EmptyChannel will empty out a error channel by checking if an error is
// present, and if so return the error.
func EmptyChannel(c chan error) error {
select {
case err := <-c:
return err
default:
return nil
}
}

View File

@ -10,6 +10,7 @@ import (
"go-micro.dev/v4/debug/trace" "go-micro.dev/v4/debug/trace"
"go-micro.dev/v4/metadata" "go-micro.dev/v4/metadata"
"go-micro.dev/v4/server" "go-micro.dev/v4/server"
"go-micro.dev/v4/transport/headers"
) )
type fromServiceWrapper struct { type fromServiceWrapper struct {
@ -19,10 +20,6 @@ type fromServiceWrapper struct {
headers metadata.Metadata headers metadata.Metadata
} }
var (
HeaderPrefix = "Micro-"
)
func (f *fromServiceWrapper) setHeaders(ctx context.Context) context.Context { func (f *fromServiceWrapper) setHeaders(ctx context.Context) context.Context {
// don't overwrite keys // don't overwrite keys
return metadata.MergeContext(ctx, f.headers, false) return metadata.MergeContext(ctx, f.headers, false)
@ -48,7 +45,7 @@ func FromService(name string, c client.Client) client.Client {
return &fromServiceWrapper{ return &fromServiceWrapper{
c, c,
metadata.Metadata{ metadata.Metadata{
HeaderPrefix + "From-Service": name, headers.Prefix + "From-Service": name,
}, },
} }
} }
@ -159,8 +156,8 @@ func (a *authWrapper) Call(ctx context.Context, req client.Request, rsp interfac
} }
// set the namespace header if it has not been set (e.g. on a service to service request) // set the namespace header if it has not been set (e.g. on a service to service request)
if _, ok := metadata.Get(ctx, "Micro-Namespace"); !ok { if _, ok := metadata.Get(ctx, headers.Namespace); !ok {
ctx = metadata.Set(ctx, "Micro-Namespace", aa.Options().Namespace) ctx = metadata.Set(ctx, headers.Namespace, aa.Options().Namespace)
} }
// check to see if we have a valid access token // check to see if we have a valid access token