mirror of
				https://github.com/go-micro/go-micro.git
				synced 2025-10-30 23:27:41 +02:00 
			
		
		
		
	feat: add test framework & refactor RPC server (#2579)
Co-authored-by: Rene Jochum <rene@jochum.dev>
This commit is contained in:
		
							
								
								
									
										4
									
								
								.github/workflows/tests.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.github/workflows/tests.yaml
									
									
									
									
										vendored
									
									
								
							| @@ -39,7 +39,7 @@ jobs: | ||||
|           go get -v -t -d ./... | ||||
|       - name: Run tests | ||||
|         id: tests | ||||
|         run: richgo test -v -race -cover ./... | ||||
|         run: richgo test -v -race -cover -bench=. ./... | ||||
|         env: | ||||
|           IN_TRAVIS_CI: yes | ||||
|           RICHGO_FORCE_COLOR: 1 | ||||
| @@ -60,6 +60,6 @@ jobs: | ||||
|           go get -v -t -d ./... | ||||
|       - name: Run tests | ||||
|         id: tests | ||||
|         run: go test -v -race -cover -json ./... | tparse -notests -format=markdown >> $GITHUB_STEP_SUMMARY | ||||
|         run: go test -v -race -cover -json -bench=. ./... | tparse -notests -format=markdown >> $GITHUB_STEP_SUMMARY | ||||
|         env: | ||||
|           IN_TRAVIS_CI: yes | ||||
|   | ||||
							
								
								
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -35,3 +35,7 @@ _cgo_export.* | ||||
| *~ | ||||
| *.swp | ||||
| *.swo | ||||
|  | ||||
| # go work files | ||||
| go.work | ||||
| go.work.sum | ||||
|   | ||||
| @@ -57,6 +57,11 @@ output: | ||||
|  | ||||
| # all available settings of specific linters | ||||
| linters-settings: | ||||
|   wsl: | ||||
|     allow-cuddle-with-calls: ["Lock", "RLock", "defer"] | ||||
|   funlen: | ||||
|     lines: 80 | ||||
|     statements: 60 | ||||
|   varnamelen: | ||||
|     # The longest distance, in source lines, that is being considered a "small scope". | ||||
|     # Variables used in at most this many lines will be ignored. | ||||
| @@ -184,6 +189,7 @@ linters: | ||||
|     - makezero | ||||
|     - gofumpt | ||||
|     - nlreturn | ||||
|     - thelper | ||||
|  | ||||
|     # Can be considered to be enabled | ||||
|     - gochecknoinits | ||||
| @@ -197,6 +203,9 @@ linters: | ||||
|     - exhaustruct | ||||
|     - containedctx | ||||
|     - godox | ||||
|     - forcetypeassert | ||||
|     - gci | ||||
|     - lll | ||||
|  | ||||
| issues: | ||||
|   # List of regexps of issue texts to exclude, empty list by default. | ||||
|   | ||||
| @@ -14,7 +14,7 @@ skipStyle: | ||||
|   foreground: lightBlack | ||||
| passPackageStyle: | ||||
|   foreground: green | ||||
|   hide: true | ||||
|   hide: false | ||||
| failPackageStyle: | ||||
|   bold: true | ||||
|   foreground: "#821515" | ||||
|   | ||||
| @@ -20,6 +20,7 @@ import ( | ||||
| 	merr "go-micro.dev/v4/errors" | ||||
| 	"go-micro.dev/v4/registry" | ||||
| 	"go-micro.dev/v4/registry/cache" | ||||
| 	"go-micro.dev/v4/transport/headers" | ||||
| 	maddr "go-micro.dev/v4/util/addr" | ||||
| 	mnet "go-micro.dev/v4/util/net" | ||||
| 	mls "go-micro.dev/v4/util/tls" | ||||
| @@ -313,7 +314,7 @@ func (h *httpBroker) ServeHTTP(w http.ResponseWriter, req *http.Request) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	topic := m.Header["Micro-Topic"] | ||||
| 	topic := m.Header[headers.Message] | ||||
| 	// delete(m.Header, ":topic") | ||||
|  | ||||
| 	if len(topic) == 0 { | ||||
| @@ -517,7 +518,7 @@ func (h *httpBroker) Publish(topic string, msg *Message, opts ...PublishOption) | ||||
| 		m.Header[k] = v | ||||
| 	} | ||||
|  | ||||
| 	m.Header["Micro-Topic"] = topic | ||||
| 	m.Header[headers.Message] = topic | ||||
|  | ||||
| 	// encode the message | ||||
| 	b, err := h.opts.Codec.Marshal(m) | ||||
|   | ||||
| @@ -8,7 +8,9 @@ import ( | ||||
| 	"time" | ||||
|  | ||||
| 	cache "github.com/patrickmn/go-cache" | ||||
|  | ||||
| 	"go-micro.dev/v4/metadata" | ||||
| 	"go-micro.dev/v4/transport/headers" | ||||
| ) | ||||
|  | ||||
| // NewCache returns an initialized cache. | ||||
| @@ -38,6 +40,7 @@ func (c *Cache) List() map[string]string { | ||||
| 	items := c.cache.Items() | ||||
|  | ||||
| 	rsp := make(map[string]string, len(items)) | ||||
|  | ||||
| 	for k, v := range items { | ||||
| 		bytes, _ := json.Marshal(v.Object) | ||||
| 		rsp[k] = string(bytes) | ||||
| @@ -48,7 +51,7 @@ func (c *Cache) List() map[string]string { | ||||
|  | ||||
| // key returns a hash for the context and request. | ||||
| func key(ctx context.Context, req *Request) string { | ||||
| 	ns, _ := metadata.Get(ctx, "Micro-Namespace") | ||||
| 	ns, _ := metadata.Get(ctx, headers.Namespace) | ||||
|  | ||||
| 	bytes, _ := json.Marshal(map[string]interface{}{ | ||||
| 		"namespace": ns, | ||||
| @@ -62,5 +65,6 @@ func key(ctx context.Context, req *Request) string { | ||||
|  | ||||
| 	h := fnv.New64() | ||||
| 	h.Write(bytes) | ||||
|  | ||||
| 	return fmt.Sprintf("%x", h.Sum(nil)) | ||||
| } | ||||
|   | ||||
| @@ -6,6 +6,7 @@ import ( | ||||
| 	"time" | ||||
|  | ||||
| 	"go-micro.dev/v4/metadata" | ||||
| 	"go-micro.dev/v4/transport/headers" | ||||
| ) | ||||
|  | ||||
| func TestCache(t *testing.T) { | ||||
| @@ -65,7 +66,7 @@ func TestCacheKey(t *testing.T) { | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("DifferentMetadata", func(t *testing.T) { | ||||
| 		mdCtx := metadata.Set(context.TODO(), "Micro-Namespace", "bar") | ||||
| 		mdCtx := metadata.Set(context.TODO(), headers.Namespace, "bar") | ||||
| 		key1 := key(mdCtx, &req1) | ||||
| 		key2 := key(ctx, &req1) | ||||
|  | ||||
|   | ||||
| @@ -3,11 +3,17 @@ package client | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"time" | ||||
|  | ||||
| 	"go-micro.dev/v4/codec" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	// NewClient returns a new client. | ||||
| 	NewClient func(...Option) Client = newRPCClient | ||||
| 	// DefaultClient is a default client to use out of the box. | ||||
| 	DefaultClient Client = newRPCClient() | ||||
| ) | ||||
|  | ||||
| // Client is the interface used to make requests to services. | ||||
| // It supports Request/Response via Transport and Publishing via the Broker. | ||||
| // It also supports bidirectional streaming of requests. | ||||
| @@ -102,26 +108,6 @@ type MessageOption func(*MessageOptions) | ||||
| // RequestOption used by NewRequest. | ||||
| type RequestOption func(*RequestOptions) | ||||
|  | ||||
| var ( | ||||
| 	// DefaultClient is a default client to use out of the box. | ||||
| 	DefaultClient Client = newRpcClient() | ||||
| 	// DefaultBackoff is the default backoff function for retries. | ||||
| 	DefaultBackoff = exponentialBackoff | ||||
| 	// DefaultRetry is the default check-for-retry function for retries. | ||||
| 	DefaultRetry = RetryOnError | ||||
| 	// DefaultRetries is the default number of times a request is tried. | ||||
| 	DefaultRetries = 1 | ||||
| 	// DefaultRequestTimeout is the default request timeout. | ||||
| 	DefaultRequestTimeout = time.Second * 5 | ||||
| 	// DefaultPoolSize sets the connection pool size. | ||||
| 	DefaultPoolSize = 100 | ||||
| 	// DefaultPoolTTL sets the connection pool ttl. | ||||
| 	DefaultPoolTTL = time.Minute | ||||
|  | ||||
| 	// NewClient returns a new client. | ||||
| 	NewClient func(...Option) Client = newRpcClient | ||||
| ) | ||||
|  | ||||
| // Makes a synchronous call to a service using the default client. | ||||
| func Call(ctx context.Context, request Request, response interface{}, opts ...CallOption) error { | ||||
| 	return DefaultClient.Call(ctx, request, response, opts...) | ||||
|   | ||||
| @@ -12,6 +12,24 @@ import ( | ||||
| 	"go-micro.dev/v4/transport" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	// DefaultBackoff is the default backoff function for retries. | ||||
| 	DefaultBackoff = exponentialBackoff | ||||
| 	// DefaultRetry is the default check-for-retry function for retries. | ||||
| 	DefaultRetry = RetryOnError | ||||
| 	// DefaultRetries is the default number of times a request is tried. | ||||
| 	DefaultRetries = 5 | ||||
| 	// DefaultRequestTimeout is the default request timeout. | ||||
| 	DefaultRequestTimeout = time.Second * 30 | ||||
| 	// DefaultConnectionTimeout is the default connection timeout. | ||||
| 	DefaultConnectionTimeout = time.Second * 5 | ||||
| 	// DefaultPoolSize sets the connection pool size. | ||||
| 	DefaultPoolSize = 100 | ||||
| 	// DefaultPoolTTL sets the connection pool ttl. | ||||
| 	DefaultPoolTTL = time.Minute | ||||
| ) | ||||
|  | ||||
| // Options are the Client options. | ||||
| type Options struct { | ||||
| 	// Used to select codec | ||||
| 	ContentType string | ||||
| @@ -47,6 +65,7 @@ type Options struct { | ||||
| 	Context context.Context | ||||
| } | ||||
|  | ||||
| // CallOptions are options used to make calls to a server. | ||||
| type CallOptions struct { | ||||
| 	SelectOptions []selector.SelectOption | ||||
|  | ||||
| @@ -56,11 +75,14 @@ type CallOptions struct { | ||||
| 	Backoff BackoffFunc | ||||
| 	// Check if retriable func | ||||
| 	Retry RetryFunc | ||||
| 	// Transport Dial Timeout | ||||
| 	DialTimeout time.Duration | ||||
| 	// Number of Call attempts | ||||
| 	Retries int | ||||
| 	// Request/Response timeout | ||||
| 	// Transport Dial Timeout. Used for initial dial to establish a connection. | ||||
| 	DialTimeout time.Duration | ||||
| 	// ConnectionTimeout of one request to the server. | ||||
| 	// Set this lower than the RequestTimeout to enbale retries on connection timeout. | ||||
| 	ConnectionTimeout time.Duration | ||||
| 	// Request/Response timeout of entire srv.Call, for single request timeout set ConnectionTimeout. | ||||
| 	RequestTimeout time.Duration | ||||
| 	// Stream timeout for the stream | ||||
| 	StreamTimeout time.Duration | ||||
| @@ -68,6 +90,8 @@ type CallOptions struct { | ||||
| 	ServiceToken bool | ||||
| 	// Duration to cache the response for | ||||
| 	CacheExpiry time.Duration | ||||
| 	// ConnClose sets the Connection: close header. | ||||
| 	ConnClose bool | ||||
|  | ||||
| 	// Middleware for low level call func | ||||
| 	CallWrappers []CallWrapper | ||||
| @@ -98,6 +122,7 @@ type RequestOptions struct { | ||||
| 	Context context.Context | ||||
| } | ||||
|  | ||||
| // NewOptions creates new Client options. | ||||
| func NewOptions(options ...Option) Options { | ||||
| 	opts := Options{ | ||||
| 		Cache:       NewCache(), | ||||
| @@ -105,11 +130,12 @@ func NewOptions(options ...Option) Options { | ||||
| 		ContentType: DefaultContentType, | ||||
| 		Codecs:      make(map[string]codec.NewCodec), | ||||
| 		CallOptions: CallOptions{ | ||||
| 			Backoff:        DefaultBackoff, | ||||
| 			Retry:          DefaultRetry, | ||||
| 			Retries:        DefaultRetries, | ||||
| 			RequestTimeout: DefaultRequestTimeout, | ||||
| 			DialTimeout:    transport.DefaultDialTimeout, | ||||
| 			Backoff:           DefaultBackoff, | ||||
| 			Retry:             DefaultRetry, | ||||
| 			Retries:           DefaultRetries, | ||||
| 			RequestTimeout:    DefaultRequestTimeout, | ||||
| 			ConnectionTimeout: DefaultConnectionTimeout, | ||||
| 			DialTimeout:       transport.DefaultDialTimeout, | ||||
| 		}, | ||||
| 		PoolSize:  DefaultPoolSize, | ||||
| 		PoolTTL:   DefaultPoolTTL, | ||||
| @@ -141,7 +167,7 @@ func Codec(contentType string, c codec.NewCodec) Option { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Default content type of the client. | ||||
| // ContentType sets the default content type of the client. | ||||
| func ContentType(ct string) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.ContentType = ct | ||||
| @@ -207,8 +233,7 @@ func Backoff(fn BackoffFunc) Option { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Number of retries when making the request. | ||||
| // Should this be a Call Option? | ||||
| // Retries set the number of retries when making the request. | ||||
| func Retries(i int) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.CallOptions.Retries = i | ||||
| @@ -222,8 +247,7 @@ func Retry(fn RetryFunc) Option { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // The request timeout. | ||||
| // Should this be a Call Option? | ||||
| // RequestTimeout set the request timeout. | ||||
| func RequestTimeout(d time.Duration) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.CallOptions.RequestTimeout = d | ||||
| @@ -237,7 +261,7 @@ func StreamTimeout(d time.Duration) Option { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Transport dial timeout. | ||||
| // DialTimeout sets the transport dial timeout. | ||||
| func DialTimeout(d time.Duration) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.CallOptions.DialTimeout = d | ||||
| @@ -296,8 +320,8 @@ func WithRetry(fn RetryFunc) CallOption { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithRetries is a CallOption which overrides that which | ||||
| // set in Options.CallOptions. | ||||
| // WithRetries sets the number of tries for a call. | ||||
| // This CallOption overrides Options.CallOptions. | ||||
| func WithRetries(i int) CallOption { | ||||
| 	return func(o *CallOptions) { | ||||
| 		o.Retries = i | ||||
| @@ -312,6 +336,13 @@ func WithRequestTimeout(d time.Duration) CallOption { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithConnClose sets the Connection header to close. | ||||
| func WithConnClose() CallOption { | ||||
| 	return func(o *CallOptions) { | ||||
| 		o.ConnClose = true | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithStreamTimeout sets the stream timeout. | ||||
| func WithStreamTimeout(d time.Duration) CallOption { | ||||
| 	return func(o *CallOptions) { | ||||
|   | ||||
| @@ -26,8 +26,9 @@ func RetryOnError(ctx context.Context, req Request, retryCount int, err error) ( | ||||
| 	} | ||||
|  | ||||
| 	switch e.Code { | ||||
| 	// retry on timeout or internal server error | ||||
| 	case 408, 500: | ||||
| 	// Retry on timeout, not on 500 internal server error, as that is a business | ||||
| 	// logic error that should be handled by the user. | ||||
| 	case 408: | ||||
| 		return true, nil | ||||
| 	default: | ||||
| 		return false, nil | ||||
|   | ||||
| @@ -3,31 +3,42 @@ package client | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/google/uuid" | ||||
| 	"github.com/pkg/errors" | ||||
|  | ||||
| 	"go-micro.dev/v4/broker" | ||||
| 	"go-micro.dev/v4/codec" | ||||
| 	raw "go-micro.dev/v4/codec/bytes" | ||||
| 	"go-micro.dev/v4/errors" | ||||
| 	merrors "go-micro.dev/v4/errors" | ||||
| 	log "go-micro.dev/v4/logger" | ||||
| 	"go-micro.dev/v4/metadata" | ||||
| 	"go-micro.dev/v4/registry" | ||||
| 	"go-micro.dev/v4/selector" | ||||
| 	"go-micro.dev/v4/transport" | ||||
| 	"go-micro.dev/v4/transport/headers" | ||||
| 	"go-micro.dev/v4/util/buf" | ||||
| 	"go-micro.dev/v4/util/net" | ||||
| 	"go-micro.dev/v4/util/pool" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	packageID = "go.micro.client" | ||||
| ) | ||||
|  | ||||
| type rpcClient struct { | ||||
| 	seq  uint64 | ||||
| 	once atomic.Value | ||||
| 	opts Options | ||||
| 	pool pool.Pool | ||||
|  | ||||
| 	mu sync.RWMutex | ||||
| } | ||||
|  | ||||
| func newRpcClient(opt ...Option) Client { | ||||
| func newRPCClient(opt ...Option) Client { | ||||
| 	opts := NewOptions(opt...) | ||||
|  | ||||
| 	p := pool.NewPool( | ||||
| @@ -57,14 +68,17 @@ func (r *rpcClient) newCodec(contentType string) (codec.NewCodec, error) { | ||||
| 	if c, ok := r.opts.Codecs[contentType]; ok { | ||||
| 		return c, nil | ||||
| 	} | ||||
|  | ||||
| 	if cf, ok := DefaultCodecs[contentType]; ok { | ||||
| 		return cf, nil | ||||
| 	} | ||||
|  | ||||
| 	return nil, fmt.Errorf("unsupported Content-Type: %s", contentType) | ||||
| } | ||||
|  | ||||
| func (r *rpcClient) call(ctx context.Context, node *registry.Node, req Request, resp interface{}, opts CallOptions) error { | ||||
| 	address := node.Address | ||||
| 	logger := r.Options().Logger | ||||
|  | ||||
| 	msg := &transport.Message{ | ||||
| 		Header: make(map[string]string), | ||||
| @@ -73,31 +87,43 @@ func (r *rpcClient) call(ctx context.Context, node *registry.Node, req Request, | ||||
| 	md, ok := metadata.FromContext(ctx) | ||||
| 	if ok { | ||||
| 		for k, v := range md { | ||||
| 			// don't copy Micro-Topic header, that used for pub/sub | ||||
| 			// this fix case then client uses the same context that received in subscriber | ||||
| 			if k == "Micro-Topic" { | ||||
| 			// Don't copy Micro-Topic header, that is used for pub/sub | ||||
| 			// this is fixes the case when the client uses the same context that | ||||
| 			// is received in the subscriber. | ||||
| 			if k == headers.Message { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			msg.Header[k] = v | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Set connection timeout for single requests to the server. Should be > 0 | ||||
| 	// as otherwise requests can't be made. | ||||
| 	cTimeout := opts.ConnectionTimeout | ||||
| 	if cTimeout == 0 { | ||||
| 		logger.Log(log.DebugLevel, "connection timeout was set to 0, overridng to default connection timeout") | ||||
|  | ||||
| 		cTimeout = DefaultConnectionTimeout | ||||
| 	} | ||||
|  | ||||
| 	// set timeout in nanoseconds | ||||
| 	msg.Header["Timeout"] = fmt.Sprintf("%d", opts.RequestTimeout) | ||||
| 	msg.Header["Timeout"] = fmt.Sprintf("%d", cTimeout) | ||||
| 	// set the content type for the request | ||||
| 	msg.Header["Content-Type"] = req.ContentType() | ||||
| 	// set the accept header | ||||
| 	msg.Header["Accept"] = req.ContentType() | ||||
|  | ||||
| 	// setup old protocol | ||||
| 	cf := setupProtocol(msg, node) | ||||
| 	reqCodec := setupProtocol(msg, node) | ||||
|  | ||||
| 	// no codec specified | ||||
| 	if cf == nil { | ||||
| 	if reqCodec == nil { | ||||
| 		var err error | ||||
| 		cf, err = r.newCodec(req.ContentType()) | ||||
| 		reqCodec, err = r.newCodec(req.ContentType()) | ||||
|  | ||||
| 		if err != nil { | ||||
| 			return errors.InternalServerError("go.micro.client", err.Error()) | ||||
| 			return merrors.InternalServerError("go.micro.client", err.Error()) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @@ -109,19 +135,29 @@ func (r *rpcClient) call(ctx context.Context, node *registry.Node, req Request, | ||||
| 		dOpts = append(dOpts, transport.WithTimeout(opts.DialTimeout)) | ||||
| 	} | ||||
|  | ||||
| 	if opts.ConnClose { | ||||
| 		dOpts = append(dOpts, transport.WithConnClose()) | ||||
| 	} | ||||
|  | ||||
| 	c, err := r.pool.Get(address, dOpts...) | ||||
| 	if err != nil { | ||||
| 		return errors.InternalServerError("go.micro.client", "connection error: %v", err) | ||||
| 		return merrors.InternalServerError("go.micro.client", "connection error: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	seq := atomic.AddUint64(&r.seq, 1) - 1 | ||||
| 	codec := newRpcCodec(msg, c, cf, "") | ||||
| 	codec := newRPCCodec(msg, c, reqCodec, "") | ||||
|  | ||||
| 	rsp := &rpcResponse{ | ||||
| 		socket: c, | ||||
| 		codec:  codec, | ||||
| 	} | ||||
|  | ||||
| 	releaseFunc := func(err error) { | ||||
| 		if err = r.pool.Release(c, err); err != nil { | ||||
| 			logger.Log(log.ErrorLevel, "failed to release pool", err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	stream := &rpcStream{ | ||||
| 		id:       fmt.Sprintf("%v", seq), | ||||
| 		context:  ctx, | ||||
| @@ -129,11 +165,17 @@ func (r *rpcClient) call(ctx context.Context, node *registry.Node, req Request, | ||||
| 		response: rsp, | ||||
| 		codec:    codec, | ||||
| 		closed:   make(chan bool), | ||||
| 		release:  func(err error) { r.pool.Release(c, err) }, | ||||
| 		close:    opts.ConnClose, | ||||
| 		release:  releaseFunc, | ||||
| 		sendEOS:  false, | ||||
| 	} | ||||
|  | ||||
| 	// close the stream on exiting this function | ||||
| 	defer stream.Close() | ||||
| 	defer func() { | ||||
| 		if err := stream.Close(); err != nil { | ||||
| 			logger.Log(log.ErrorLevel, "failed to close stream", err) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	// wait for error response | ||||
| 	ch := make(chan error, 1) | ||||
| @@ -141,7 +183,7 @@ func (r *rpcClient) call(ctx context.Context, node *registry.Node, req Request, | ||||
| 	go func() { | ||||
| 		defer func() { | ||||
| 			if r := recover(); r != nil { | ||||
| 				ch <- errors.InternalServerError("go.micro.client", "panic recovered: %v", r) | ||||
| 				ch <- merrors.InternalServerError("go.micro.client", "panic recovered: %v", r) | ||||
| 			} | ||||
| 		}() | ||||
|  | ||||
| @@ -166,8 +208,8 @@ func (r *rpcClient) call(ctx context.Context, node *registry.Node, req Request, | ||||
| 	select { | ||||
| 	case err := <-ch: | ||||
| 		return err | ||||
| 	case <-ctx.Done(): | ||||
| 		grr = errors.Timeout("go.micro.client", fmt.Sprintf("%v", ctx.Err())) | ||||
| 	case <-time.After(cTimeout): | ||||
| 		grr = merrors.Timeout("go.micro.client", fmt.Sprintf("%v", ctx.Err())) | ||||
| 	} | ||||
|  | ||||
| 	// set the stream error | ||||
| @@ -184,6 +226,7 @@ func (r *rpcClient) call(ctx context.Context, node *registry.Node, req Request, | ||||
|  | ||||
| func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request, opts CallOptions) (Stream, error) { | ||||
| 	address := node.Address | ||||
| 	logger := r.Options().Logger | ||||
|  | ||||
| 	msg := &transport.Message{ | ||||
| 		Header: make(map[string]string), | ||||
| @@ -206,14 +249,15 @@ func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request | ||||
| 	msg.Header["Accept"] = req.ContentType() | ||||
|  | ||||
| 	// set old codecs | ||||
| 	cf := setupProtocol(msg, node) | ||||
| 	nCodec := setupProtocol(msg, node) | ||||
|  | ||||
| 	// no codec specified | ||||
| 	if cf == nil { | ||||
| 	if nCodec == nil { | ||||
| 		var err error | ||||
| 		cf, err = r.newCodec(req.ContentType()) | ||||
|  | ||||
| 		nCodec, err = r.newCodec(req.ContentType()) | ||||
| 		if err != nil { | ||||
| 			return nil, errors.InternalServerError("go.micro.client", err.Error()) | ||||
| 			return nil, merrors.InternalServerError("go.micro.client", err.Error()) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @@ -227,7 +271,7 @@ func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request | ||||
|  | ||||
| 	c, err := r.opts.Transport.Dial(address, dOpts...) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.InternalServerError("go.micro.client", "connection error: %v", err) | ||||
| 		return nil, merrors.InternalServerError("go.micro.client", "connection error: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// increment the sequence number | ||||
| @@ -235,7 +279,7 @@ func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request | ||||
| 	id := fmt.Sprintf("%v", seq) | ||||
|  | ||||
| 	// create codec with stream id | ||||
| 	codec := newRpcCodec(msg, c, cf, id) | ||||
| 	codec := newRPCCodec(msg, c, nCodec, id) | ||||
|  | ||||
| 	rsp := &rpcResponse{ | ||||
| 		socket: c, | ||||
| @@ -247,6 +291,12 @@ func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request | ||||
| 		r.codec = codec | ||||
| 	} | ||||
|  | ||||
| 	releaseFunc := func(_ error) { | ||||
| 		if err = c.Close(); err != nil { | ||||
| 			logger.Log(log.ErrorLevel, err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	stream := &rpcStream{ | ||||
| 		id:       id, | ||||
| 		context:  ctx, | ||||
| @@ -257,8 +307,7 @@ func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request | ||||
| 		closed: make(chan bool), | ||||
| 		// signal the end of stream, | ||||
| 		sendEOS: true, | ||||
| 		// release func | ||||
| 		release: func(err error) { c.Close() }, | ||||
| 		release: releaseFunc, | ||||
| 	} | ||||
|  | ||||
| 	// wait for error response | ||||
| @@ -275,7 +324,7 @@ func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request | ||||
| 	case err := <-ch: | ||||
| 		grr = err | ||||
| 	case <-ctx.Done(): | ||||
| 		grr = errors.Timeout("go.micro.client", fmt.Sprintf("%v", ctx.Err())) | ||||
| 		grr = merrors.Timeout("go.micro.client", fmt.Sprintf("%v", ctx.Err())) | ||||
| 	} | ||||
|  | ||||
| 	if grr != nil { | ||||
| @@ -285,7 +334,10 @@ func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request | ||||
| 		stream.Unlock() | ||||
|  | ||||
| 		// close the stream | ||||
| 		stream.Close() | ||||
| 		if err := stream.Close(); err != nil { | ||||
| 			logger.Logf(log.ErrorLevel, "failed to close stream: %v", err) | ||||
| 		} | ||||
|  | ||||
| 		return nil, grr | ||||
| 	} | ||||
|  | ||||
| @@ -293,6 +345,9 @@ func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request | ||||
| } | ||||
|  | ||||
| func (r *rpcClient) Init(opts ...Option) error { | ||||
| 	r.mu.Lock() | ||||
| 	defer r.mu.Unlock() | ||||
|  | ||||
| 	size := r.opts.PoolSize | ||||
| 	ttl := r.opts.PoolTTL | ||||
| 	tr := r.opts.Transport | ||||
| @@ -304,7 +359,10 @@ func (r *rpcClient) Init(opts ...Option) error { | ||||
| 	// update pool configuration if the options changed | ||||
| 	if size != r.opts.PoolSize || ttl != r.opts.PoolTTL || tr != r.opts.Transport { | ||||
| 		// close existing pool | ||||
| 		r.pool.Close() | ||||
| 		if err := r.pool.Close(); err != nil { | ||||
| 			return errors.Wrap(err, "failed to close pool") | ||||
| 		} | ||||
|  | ||||
| 		// create new pool | ||||
| 		r.pool = pool.NewPool( | ||||
| 			pool.Size(r.opts.PoolSize), | ||||
| @@ -316,7 +374,11 @@ func (r *rpcClient) Init(opts ...Option) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // Options retrives the options. | ||||
| func (r *rpcClient) Options() Options { | ||||
| 	r.mu.RLock() | ||||
| 	defer r.mu.RUnlock() | ||||
|  | ||||
| 	return r.opts | ||||
| } | ||||
|  | ||||
| @@ -348,16 +410,22 @@ func (r *rpcClient) next(request Request, opts CallOptions) (selector.Next, erro | ||||
| 	// get next nodes from the selector | ||||
| 	next, err := r.opts.Selector.Select(service, opts.SelectOptions...) | ||||
| 	if err != nil { | ||||
| 		if err == selector.ErrNotFound { | ||||
| 			return nil, errors.InternalServerError("go.micro.client", "service %s: %s", service, err.Error()) | ||||
| 		if errors.Is(err, selector.ErrNotFound) { | ||||
| 			return nil, merrors.InternalServerError("go.micro.client", "service %s: %s", service, err.Error()) | ||||
| 		} | ||||
| 		return nil, errors.InternalServerError("go.micro.client", "error selecting %s node: %s", service, err.Error()) | ||||
|  | ||||
| 		return nil, merrors.InternalServerError("go.micro.client", "error selecting %s node: %s", service, err.Error()) | ||||
| 	} | ||||
|  | ||||
| 	return next, nil | ||||
| } | ||||
|  | ||||
| func (r *rpcClient) Call(ctx context.Context, request Request, response interface{}, opts ...CallOption) error { | ||||
| 	// TODO: further validate these mutex locks. full lock would prevent | ||||
| 	// parallel calls. Maybe we can set individual locks for secctions. | ||||
| 	r.mu.RLock() | ||||
| 	defer r.mu.RUnlock() | ||||
|  | ||||
| 	// make a copy of call opts | ||||
| 	callOpts := r.opts.CallOptions | ||||
| 	for _, opt := range opts { | ||||
| @@ -375,6 +443,7 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac | ||||
| 		// no deadline so we create a new one | ||||
| 		var cancel context.CancelFunc | ||||
| 		ctx, cancel = context.WithTimeout(ctx, callOpts.RequestTimeout) | ||||
|  | ||||
| 		defer cancel() | ||||
| 	} else { | ||||
| 		// got a deadline so no need to setup context | ||||
| @@ -386,7 +455,7 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac | ||||
| 	// should we noop right here? | ||||
| 	select { | ||||
| 	case <-ctx.Done(): | ||||
| 		return errors.Timeout("go.micro.client", fmt.Sprintf("%v", ctx.Err())) | ||||
| 		return merrors.Timeout("go.micro.client", fmt.Sprintf("%v", ctx.Err())) | ||||
| 	default: | ||||
| 	} | ||||
|  | ||||
| @@ -403,7 +472,7 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac | ||||
| 		// call backoff first. Someone may want an initial start delay | ||||
| 		t, err := callOpts.Backoff(ctx, request, i) | ||||
| 		if err != nil { | ||||
| 			return errors.InternalServerError("go.micro.client", "backoff error: %v", err.Error()) | ||||
| 			return merrors.InternalServerError("go.micro.client", "backoff error: %v", err.Error()) | ||||
| 		} | ||||
|  | ||||
| 		// only sleep if greater than 0 | ||||
| @@ -414,16 +483,19 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac | ||||
| 		// select next node | ||||
| 		node, err := next() | ||||
| 		service := request.Service() | ||||
|  | ||||
| 		if err != nil { | ||||
| 			if err == selector.ErrNotFound { | ||||
| 				return errors.InternalServerError("go.micro.client", "service %s: %s", service, err.Error()) | ||||
| 			if errors.Is(err, selector.ErrNotFound) { | ||||
| 				return merrors.InternalServerError("go.micro.client", "service %s: %s", service, err.Error()) | ||||
| 			} | ||||
| 			return errors.InternalServerError("go.micro.client", "error getting next %s node: %s", service, err.Error()) | ||||
|  | ||||
| 			return merrors.InternalServerError("go.micro.client", "error getting next %s node: %s", service, err.Error()) | ||||
| 		} | ||||
|  | ||||
| 		// make the call | ||||
| 		err = rcall(ctx, node, request, response, callOpts) | ||||
| 		r.opts.Selector.Mark(service, node, err) | ||||
|  | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| @@ -431,11 +503,13 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac | ||||
| 	retries := callOpts.Retries | ||||
|  | ||||
| 	// disable retries when using a proxy | ||||
| 	if _, _, ok := net.Proxy(request.Service(), callOpts.Address); ok { | ||||
| 		retries = 0 | ||||
| 	} | ||||
| 	// Note: I don't see why we should disable retries for proxies, so commenting out. | ||||
| 	// if _, _, ok := net.Proxy(request.Service(), callOpts.Address); ok { | ||||
| 	// 	retries = 0 | ||||
| 	// } | ||||
|  | ||||
| 	ch := make(chan error, retries+1) | ||||
|  | ||||
| 	var gerr error | ||||
|  | ||||
| 	for i := 0; i <= retries; i++ { | ||||
| @@ -445,7 +519,7 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac | ||||
|  | ||||
| 		select { | ||||
| 		case <-ctx.Done(): | ||||
| 			return errors.Timeout("go.micro.client", fmt.Sprintf("call timeout: %v", ctx.Err())) | ||||
| 			return merrors.Timeout("go.micro.client", fmt.Sprintf("call timeout: %v", ctx.Err())) | ||||
| 		case err := <-ch: | ||||
| 			// if the call succeeded lets bail early | ||||
| 			if err == nil { | ||||
| @@ -461,6 +535,8 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac | ||||
| 				return err | ||||
| 			} | ||||
|  | ||||
| 			r.opts.Logger.Logf(log.DebugLevel, "Retrying request. Previous attempt failed with: %v", err) | ||||
|  | ||||
| 			gerr = err | ||||
| 		} | ||||
| 	} | ||||
| @@ -469,6 +545,9 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac | ||||
| } | ||||
|  | ||||
| func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOption) (Stream, error) { | ||||
| 	r.mu.RLock() | ||||
| 	defer r.mu.RUnlock() | ||||
|  | ||||
| 	// make a copy of call opts | ||||
| 	callOpts := r.opts.CallOptions | ||||
| 	for _, opt := range opts { | ||||
| @@ -480,10 +559,9 @@ func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOpt | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	// should we noop right here? | ||||
| 	select { | ||||
| 	case <-ctx.Done(): | ||||
| 		return nil, errors.Timeout("go.micro.client", fmt.Sprintf("%v", ctx.Err())) | ||||
| 		return nil, merrors.Timeout("go.micro.client", fmt.Sprintf("%v", ctx.Err())) | ||||
| 	default: | ||||
| 	} | ||||
|  | ||||
| @@ -491,7 +569,7 @@ func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOpt | ||||
| 		// call backoff first. Someone may want an initial start delay | ||||
| 		t, err := callOpts.Backoff(ctx, request, i) | ||||
| 		if err != nil { | ||||
| 			return nil, errors.InternalServerError("go.micro.client", "backoff error: %v", err.Error()) | ||||
| 			return nil, merrors.InternalServerError("go.micro.client", "backoff error: %v", err.Error()) | ||||
| 		} | ||||
|  | ||||
| 		// only sleep if greater than 0 | ||||
| @@ -501,15 +579,18 @@ func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOpt | ||||
|  | ||||
| 		node, err := next() | ||||
| 		service := request.Service() | ||||
|  | ||||
| 		if err != nil { | ||||
| 			if err == selector.ErrNotFound { | ||||
| 				return nil, errors.InternalServerError("go.micro.client", "service %s: %s", service, err.Error()) | ||||
| 			if errors.Is(err, selector.ErrNotFound) { | ||||
| 				return nil, merrors.InternalServerError("go.micro.client", "service %s: %s", service, err.Error()) | ||||
| 			} | ||||
| 			return nil, errors.InternalServerError("go.micro.client", "error getting next %s node: %s", service, err.Error()) | ||||
|  | ||||
| 			return nil, merrors.InternalServerError("go.micro.client", "error getting next %s node: %s", service, err.Error()) | ||||
| 		} | ||||
|  | ||||
| 		stream, err := r.stream(ctx, node, request, callOpts) | ||||
| 		r.opts.Selector.Mark(service, node, err) | ||||
|  | ||||
| 		return stream, err | ||||
| 	} | ||||
|  | ||||
| @@ -527,6 +608,7 @@ func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOpt | ||||
| 	} | ||||
|  | ||||
| 	ch := make(chan response, retries+1) | ||||
|  | ||||
| 	var grr error | ||||
|  | ||||
| 	for i := 0; i <= retries; i++ { | ||||
| @@ -537,7 +619,7 @@ func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOpt | ||||
|  | ||||
| 		select { | ||||
| 		case <-ctx.Done(): | ||||
| 			return nil, errors.Timeout("go.micro.client", fmt.Sprintf("call timeout: %v", ctx.Err())) | ||||
| 			return nil, merrors.Timeout("go.micro.client", fmt.Sprintf("call timeout: %v", ctx.Err())) | ||||
| 		case rsp := <-ch: | ||||
| 			// if the call succeeded lets bail early | ||||
| 			if rsp.err == nil { | ||||
| @@ -568,15 +650,15 @@ func (r *rpcClient) Publish(ctx context.Context, msg Message, opts ...PublishOpt | ||||
| 		o(&options) | ||||
| 	} | ||||
|  | ||||
| 	md, ok := metadata.FromContext(ctx) | ||||
| 	metadata, ok := metadata.FromContext(ctx) | ||||
| 	if !ok { | ||||
| 		md = make(map[string]string) | ||||
| 		metadata = make(map[string]string) | ||||
| 	} | ||||
|  | ||||
| 	id := uuid.New().String() | ||||
| 	md["Content-Type"] = msg.ContentType() | ||||
| 	md["Micro-Topic"] = msg.Topic() | ||||
| 	md["Micro-Id"] = id | ||||
| 	metadata["Content-Type"] = msg.ContentType() | ||||
| 	metadata[headers.Message] = msg.Topic() | ||||
| 	metadata[headers.ID] = id | ||||
|  | ||||
| 	// set the topic | ||||
| 	topic := msg.Topic() | ||||
| @@ -589,7 +671,7 @@ func (r *rpcClient) Publish(ctx context.Context, msg Message, opts ...PublishOpt | ||||
| 	// encode message body | ||||
| 	cf, err := r.newCodec(msg.ContentType()) | ||||
| 	if err != nil { | ||||
| 		return errors.InternalServerError("go.micro.client", err.Error()) | ||||
| 		return merrors.InternalServerError(packageID, err.Error()) | ||||
| 	} | ||||
|  | ||||
| 	var body []byte | ||||
| @@ -598,33 +680,38 @@ func (r *rpcClient) Publish(ctx context.Context, msg Message, opts ...PublishOpt | ||||
| 	if d, ok := msg.Payload().(*raw.Frame); ok { | ||||
| 		body = d.Data | ||||
| 	} else { | ||||
| 		// new buffer | ||||
| 		b := buf.New(nil) | ||||
|  | ||||
| 		if err := cf(b).Write(&codec.Message{ | ||||
| 		if err = cf(b).Write(&codec.Message{ | ||||
| 			Target: topic, | ||||
| 			Type:   codec.Event, | ||||
| 			Header: map[string]string{ | ||||
| 				"Micro-Id":    id, | ||||
| 				"Micro-Topic": msg.Topic(), | ||||
| 				headers.ID:      id, | ||||
| 				headers.Message: msg.Topic(), | ||||
| 			}, | ||||
| 		}, msg.Payload()); err != nil { | ||||
| 			return errors.InternalServerError("go.micro.client", err.Error()) | ||||
| 			return merrors.InternalServerError(packageID, err.Error()) | ||||
| 		} | ||||
|  | ||||
| 		// set the body | ||||
| 		body = b.Bytes() | ||||
| 	} | ||||
|  | ||||
| 	if !r.once.Load().(bool) { | ||||
| 	l, ok := r.once.Load().(bool) | ||||
| 	if !ok { | ||||
| 		return fmt.Errorf("failed to cast to bool") | ||||
| 	} | ||||
|  | ||||
| 	if !l { | ||||
| 		if err = r.opts.Broker.Connect(); err != nil { | ||||
| 			return errors.InternalServerError("go.micro.client", err.Error()) | ||||
| 			return merrors.InternalServerError(packageID, err.Error()) | ||||
| 		} | ||||
|  | ||||
| 		r.once.Store(true) | ||||
| 	} | ||||
|  | ||||
| 	return r.opts.Broker.Publish(topic, &broker.Message{ | ||||
| 		Header: md, | ||||
| 		Header: metadata, | ||||
| 		Body:   body, | ||||
| 	}, broker.PublishContext(options.Context)) | ||||
| } | ||||
|   | ||||
| @@ -10,18 +10,23 @@ import ( | ||||
| 	"go-micro.dev/v4/selector" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	serviceName     = "test.service" | ||||
| 	serviceEndpoint = "Test.Endpoint" | ||||
| ) | ||||
|  | ||||
| func newTestRegistry() registry.Registry { | ||||
| 	return registry.NewMemoryRegistry(registry.Services(testData)) | ||||
| } | ||||
|  | ||||
| func TestCallAddress(t *testing.T) { | ||||
| 	var called bool | ||||
| 	service := "test.service" | ||||
| 	endpoint := "Test.Endpoint" | ||||
| 	service := serviceName | ||||
| 	endpoint := serviceEndpoint | ||||
| 	address := "10.1.10.1:8080" | ||||
|  | ||||
| 	wrap := func(cf CallFunc) CallFunc { | ||||
| 		return func(ctx context.Context, node *registry.Node, req Request, rsp interface{}, opts CallOptions) error { | ||||
| 		return func(_ context.Context, node *registry.Node, req Request, _ interface{}, _ CallOptions) error { | ||||
| 			called = true | ||||
|  | ||||
| 			if req.Service() != service { | ||||
| @@ -46,7 +51,10 @@ func TestCallAddress(t *testing.T) { | ||||
| 		Registry(r), | ||||
| 		WrapCall(wrap), | ||||
| 	) | ||||
| 	c.Options().Selector.Init(selector.Registry(r)) | ||||
|  | ||||
| 	if err := c.Options().Selector.Init(selector.Registry(r)); err != nil { | ||||
| 		t.Fatal("failed to initialize selector", err) | ||||
| 	} | ||||
|  | ||||
| 	req := c.NewRequest(service, endpoint, nil) | ||||
|  | ||||
| @@ -68,7 +76,7 @@ func TestCallRetry(t *testing.T) { | ||||
| 	var called int | ||||
|  | ||||
| 	wrap := func(cf CallFunc) CallFunc { | ||||
| 		return func(ctx context.Context, node *registry.Node, req Request, rsp interface{}, opts CallOptions) error { | ||||
| 		return func(_ context.Context, _ *registry.Node, _ Request, _ interface{}, _ CallOptions) error { | ||||
| 			called++ | ||||
| 			if called == 1 { | ||||
| 				return errors.InternalServerError("test.error", "retry request") | ||||
| @@ -84,7 +92,10 @@ func TestCallRetry(t *testing.T) { | ||||
| 		Registry(r), | ||||
| 		WrapCall(wrap), | ||||
| 	) | ||||
| 	c.Options().Selector.Init(selector.Registry(r)) | ||||
|  | ||||
| 	if err := c.Options().Selector.Init(selector.Registry(r)); err != nil { | ||||
| 		t.Fatal("failed to initialize selector", err) | ||||
| 	} | ||||
|  | ||||
| 	req := c.NewRequest(service, endpoint, nil) | ||||
|  | ||||
| @@ -107,7 +118,7 @@ func TestCallWrapper(t *testing.T) { | ||||
| 	address := "10.1.10.1:8080" | ||||
|  | ||||
| 	wrap := func(cf CallFunc) CallFunc { | ||||
| 		return func(ctx context.Context, node *registry.Node, req Request, rsp interface{}, opts CallOptions) error { | ||||
| 		return func(_ context.Context, node *registry.Node, req Request, _ interface{}, _ CallOptions) error { | ||||
| 			called = true | ||||
|  | ||||
| 			if req.Service() != service { | ||||
| @@ -132,9 +143,12 @@ func TestCallWrapper(t *testing.T) { | ||||
| 		Registry(r), | ||||
| 		WrapCall(wrap), | ||||
| 	) | ||||
| 	c.Options().Selector.Init(selector.Registry(r)) | ||||
|  | ||||
| 	r.Register(®istry.Service{ | ||||
| 	if err := c.Options().Selector.Init(selector.Registry(r)); err != nil { | ||||
| 		t.Fatal("failed to initialize selector", err) | ||||
| 	} | ||||
|  | ||||
| 	err := r.Register(®istry.Service{ | ||||
| 		Name:    service, | ||||
| 		Version: "latest", | ||||
| 		Nodes: []*registry.Node{ | ||||
| @@ -147,6 +161,9 @@ func TestCallWrapper(t *testing.T) { | ||||
| 			}, | ||||
| 		}, | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		t.Fatal("failed to register service", err) | ||||
| 	} | ||||
|  | ||||
| 	req := c.NewRequest(service, endpoint, nil) | ||||
| 	if err := c.Call(context.Background(), req, nil); err != nil { | ||||
|   | ||||
| @@ -14,6 +14,7 @@ import ( | ||||
| 	"go-micro.dev/v4/errors" | ||||
| 	"go-micro.dev/v4/registry" | ||||
| 	"go-micro.dev/v4/transport" | ||||
| 	"go-micro.dev/v4/transport/headers" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| @@ -50,8 +51,10 @@ type readWriteCloser struct { | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	// DefaultContentType header. | ||||
| 	DefaultContentType = "application/json" | ||||
|  | ||||
| 	// DefaultCodecs map. | ||||
| 	DefaultCodecs = map[string]codec.NewCodec{ | ||||
| 		"application/grpc":         grpc.NewCodec, | ||||
| 		"application/grpc+json":    grpc.NewCodec, | ||||
| @@ -84,6 +87,7 @@ func (rwc *readWriteCloser) Write(p []byte) (n int, err error) { | ||||
| func (rwc *readWriteCloser) Close() error { | ||||
| 	rwc.rbuf.Reset() | ||||
| 	rwc.wbuf.Reset() | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| @@ -92,20 +96,21 @@ func getHeaders(m *codec.Message) { | ||||
| 		if len(v) > 0 { | ||||
| 			return v | ||||
| 		} | ||||
|  | ||||
| 		return m.Header[hdr] | ||||
| 	} | ||||
|  | ||||
| 	// check error in header | ||||
| 	m.Error = set(m.Error, "Micro-Error") | ||||
| 	m.Error = set(m.Error, headers.Error) | ||||
|  | ||||
| 	// check endpoint in header | ||||
| 	m.Endpoint = set(m.Endpoint, "Micro-Endpoint") | ||||
| 	m.Endpoint = set(m.Endpoint, headers.Endpoint) | ||||
|  | ||||
| 	// check method in header | ||||
| 	m.Method = set(m.Method, "Micro-Method") | ||||
| 	m.Method = set(m.Method, headers.Method) | ||||
|  | ||||
| 	// set the request id | ||||
| 	m.Id = set(m.Id, "Micro-Id") | ||||
| 	m.Id = set(m.Id, headers.ID) | ||||
| } | ||||
|  | ||||
| func setHeaders(m *codec.Message, stream string) { | ||||
| @@ -113,17 +118,18 @@ func setHeaders(m *codec.Message, stream string) { | ||||
| 		if len(v) == 0 { | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		m.Header[hdr] = v | ||||
| 	} | ||||
|  | ||||
| 	set("Micro-Id", m.Id) | ||||
| 	set("Micro-Service", m.Target) | ||||
| 	set("Micro-Method", m.Method) | ||||
| 	set("Micro-Endpoint", m.Endpoint) | ||||
| 	set("Micro-Error", m.Error) | ||||
| 	set(headers.ID, m.Id) | ||||
| 	set(headers.Request, m.Target) | ||||
| 	set(headers.Method, m.Method) | ||||
| 	set(headers.Endpoint, m.Endpoint) | ||||
| 	set(headers.Error, m.Error) | ||||
|  | ||||
| 	if len(stream) > 0 { | ||||
| 		set("Micro-Stream", stream) | ||||
| 		set(headers.Stream, stream) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -137,7 +143,7 @@ func setupProtocol(msg *transport.Message, node *registry.Node) codec.NewCodec { | ||||
| 	} | ||||
|  | ||||
| 	// processing topic publishing | ||||
| 	if len(msg.Header["Micro-Topic"]) > 0 { | ||||
| 	if len(msg.Header[headers.Message]) > 0 { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| @@ -149,60 +155,59 @@ func setupProtocol(msg *transport.Message, node *registry.Node) codec.NewCodec { | ||||
| 		msg.Header["Content-Type"] = "application/proto-rpc" | ||||
| 	} | ||||
|  | ||||
| 	// now return codec | ||||
| 	return defaultCodecs[msg.Header["Content-Type"]] | ||||
| } | ||||
|  | ||||
| func newRpcCodec(req *transport.Message, client transport.Client, c codec.NewCodec, stream string) codec.Codec { | ||||
| func newRPCCodec(req *transport.Message, client transport.Client, c codec.NewCodec, stream string) codec.Codec { | ||||
| 	rwc := &readWriteCloser{ | ||||
| 		wbuf: bytes.NewBuffer(nil), | ||||
| 		rbuf: bytes.NewBuffer(nil), | ||||
| 	} | ||||
| 	r := &rpcCodec{ | ||||
|  | ||||
| 	return &rpcCodec{ | ||||
| 		buf:    rwc, | ||||
| 		client: client, | ||||
| 		codec:  c(rwc), | ||||
| 		req:    req, | ||||
| 		stream: stream, | ||||
| 	} | ||||
| 	return r | ||||
| } | ||||
|  | ||||
| func (c *rpcCodec) Write(m *codec.Message, body interface{}) error { | ||||
| func (c *rpcCodec) Write(message *codec.Message, body interface{}) error { | ||||
| 	c.buf.wbuf.Reset() | ||||
|  | ||||
| 	// create header | ||||
| 	if m.Header == nil { | ||||
| 		m.Header = map[string]string{} | ||||
| 	if message.Header == nil { | ||||
| 		message.Header = map[string]string{} | ||||
| 	} | ||||
|  | ||||
| 	// copy original header | ||||
| 	for k, v := range c.req.Header { | ||||
| 		m.Header[k] = v | ||||
| 		message.Header[k] = v | ||||
| 	} | ||||
|  | ||||
| 	// set the mucp headers | ||||
| 	setHeaders(m, c.stream) | ||||
| 	setHeaders(message, c.stream) | ||||
|  | ||||
| 	// if body is bytes Frame don't encode | ||||
| 	if body != nil { | ||||
| 		if b, ok := body.(*raw.Frame); ok { | ||||
| 			// set body | ||||
| 			m.Body = b.Data | ||||
| 			message.Body = b.Data | ||||
| 		} else { | ||||
| 			// write to codec | ||||
| 			if err := c.codec.Write(m, body); err != nil { | ||||
| 			if err := c.codec.Write(message, body); err != nil { | ||||
| 				return errors.InternalServerError("go.micro.client.codec", err.Error()) | ||||
| 			} | ||||
| 			// set body | ||||
| 			m.Body = c.buf.wbuf.Bytes() | ||||
| 			message.Body = c.buf.wbuf.Bytes() | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// create new transport message | ||||
| 	msg := transport.Message{ | ||||
| 		Header: m.Header, | ||||
| 		Body:   m.Body, | ||||
| 		Header: message.Header, | ||||
| 		Body:   message.Body, | ||||
| 	} | ||||
|  | ||||
| 	// send the request | ||||
| @@ -213,7 +218,7 @@ func (c *rpcCodec) Write(m *codec.Message, body interface{}) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (c *rpcCodec) ReadHeader(m *codec.Message, r codec.MessageType) error { | ||||
| func (c *rpcCodec) ReadHeader(msg *codec.Message, r codec.MessageType) error { | ||||
| 	var tm transport.Message | ||||
|  | ||||
| 	// read message from transport | ||||
| @@ -225,13 +230,13 @@ func (c *rpcCodec) ReadHeader(m *codec.Message, r codec.MessageType) error { | ||||
| 	c.buf.rbuf.Write(tm.Body) | ||||
|  | ||||
| 	// set headers from transport | ||||
| 	m.Header = tm.Header | ||||
| 	msg.Header = tm.Header | ||||
|  | ||||
| 	// read header | ||||
| 	err := c.codec.ReadHeader(m, r) | ||||
| 	err := c.codec.ReadHeader(msg, r) | ||||
|  | ||||
| 	// get headers | ||||
| 	getHeaders(m) | ||||
| 	getHeaders(msg) | ||||
|  | ||||
| 	// return header error | ||||
| 	if err != nil { | ||||
| @@ -252,15 +257,23 @@ func (c *rpcCodec) ReadBody(b interface{}) error { | ||||
| 	if err := c.codec.ReadBody(b); err != nil { | ||||
| 		return errors.InternalServerError("go.micro.client.codec", err.Error()) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (c *rpcCodec) Close() error { | ||||
| 	c.buf.Close() | ||||
| 	c.codec.Close() | ||||
| 	if err := c.buf.Close(); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if err := c.codec.Close(); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if err := c.client.Close(); err != nil { | ||||
| 		return errors.InternalServerError("go.micro.client.transport", err.Error()) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -12,8 +12,10 @@ import ( | ||||
| // Implements the streamer interface. | ||||
| type rpcStream struct { | ||||
| 	sync.RWMutex | ||||
| 	id       string | ||||
| 	closed   chan bool | ||||
| 	id     string | ||||
| 	closed chan bool | ||||
| 	// Indicates whether connection should be closed directly. | ||||
| 	close    bool | ||||
| 	err      error | ||||
| 	request  Request | ||||
| 	response Response | ||||
| @@ -79,6 +81,7 @@ func (r *rpcStream) Recv(msg interface{}) error { | ||||
| 	if r.isClosed() { | ||||
| 		r.err = errShutdown | ||||
| 		r.Unlock() | ||||
|  | ||||
| 		return errShutdown | ||||
| 	} | ||||
|  | ||||
| @@ -87,15 +90,19 @@ func (r *rpcStream) Recv(msg interface{}) error { | ||||
| 	r.Unlock() | ||||
| 	err := r.codec.ReadHeader(&resp, codec.Response) | ||||
| 	r.Lock() | ||||
|  | ||||
| 	if err != nil { | ||||
| 		if err == io.EOF && !r.isClosed() { | ||||
| 		if errors.Is(err, io.EOF) && !r.isClosed() { | ||||
| 			r.err = io.ErrUnexpectedEOF | ||||
| 			r.Unlock() | ||||
|  | ||||
| 			return io.ErrUnexpectedEOF | ||||
| 		} | ||||
|  | ||||
| 		r.err = err | ||||
|  | ||||
| 		r.Unlock() | ||||
|  | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| @@ -124,13 +131,15 @@ func (r *rpcStream) Recv(msg interface{}) error { | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	r.Unlock() | ||||
| 	defer r.Unlock() | ||||
|  | ||||
| 	return r.err | ||||
| } | ||||
|  | ||||
| func (r *rpcStream) Error() error { | ||||
| 	r.RLock() | ||||
| 	defer r.RUnlock() | ||||
|  | ||||
| 	return r.err | ||||
| } | ||||
|  | ||||
| @@ -152,6 +161,7 @@ func (r *rpcStream) Close() error { | ||||
| 		// send the end of stream message | ||||
| 		if r.sendEOS { | ||||
| 			// no need to check for error | ||||
| 			//nolint:errcheck,gosec | ||||
| 			r.codec.Write(&codec.Message{ | ||||
| 				Id:       r.id, | ||||
| 				Target:   r.request.Service(), | ||||
| @@ -164,10 +174,13 @@ func (r *rpcStream) Close() error { | ||||
|  | ||||
| 		err := r.codec.Close() | ||||
|  | ||||
| 		rerr := r.Error() | ||||
| 		if r.close && rerr == nil { | ||||
| 			rerr = errors.New("connection header set to close") | ||||
| 		} | ||||
| 		// release the connection | ||||
| 		r.release(r.Error()) | ||||
| 		r.release(rerr) | ||||
|  | ||||
| 		// return the codec error | ||||
| 		return err | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -10,6 +10,7 @@ import ( | ||||
|  | ||||
| 	"github.com/golang/protobuf/proto" | ||||
| 	"go-micro.dev/v4/codec" | ||||
| 	"go-micro.dev/v4/transport/headers" | ||||
| ) | ||||
|  | ||||
| type Codec struct { | ||||
| @@ -29,8 +30,8 @@ func (c *Codec) ReadHeader(m *codec.Message, t codec.MessageType) error { | ||||
| 	// service method | ||||
| 	path := m.Header[":path"] | ||||
| 	if len(path) == 0 || path[0] != '/' { | ||||
| 		m.Target = m.Header["Micro-Service"] | ||||
| 		m.Endpoint = m.Header["Micro-Endpoint"] | ||||
| 		m.Target = m.Header[headers.Request] | ||||
| 		m.Endpoint = m.Header[headers.Endpoint] | ||||
| 	} else { | ||||
| 		// [ , a.package.Foo, Bar] | ||||
| 		parts := strings.Split(path, "/") | ||||
|   | ||||
| @@ -3,6 +3,8 @@ package handler | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"io" | ||||
| 	"time" | ||||
|  | ||||
| 	"go-micro.dev/v4/client" | ||||
| @@ -10,7 +12,6 @@ import ( | ||||
| 	proto "go-micro.dev/v4/debug/proto" | ||||
| 	"go-micro.dev/v4/debug/stats" | ||||
| 	"go-micro.dev/v4/debug/trace" | ||||
| 	"go-micro.dev/v4/server" | ||||
| ) | ||||
|  | ||||
| // NewHandler returns an instance of the Debug Handler. | ||||
| @@ -22,6 +23,8 @@ func NewHandler(c client.Client) *Debug { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| var _ proto.DebugHandler = (*Debug)(nil) | ||||
|  | ||||
| type Debug struct { | ||||
| 	// must honor the debug handler | ||||
| 	proto.DebugHandler | ||||
| @@ -38,6 +41,25 @@ func (d *Debug) Health(ctx context.Context, req *proto.HealthRequest, rsp *proto | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (d *Debug) MessageBus(ctx context.Context, stream proto.Debug_MessageBusStream) error { | ||||
| 	for { | ||||
| 		_, err := stream.Recv() | ||||
| 		if errors.Is(err, io.EOF) { | ||||
| 			return nil | ||||
| 		} else if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		rsp := proto.BusMsg{ | ||||
| 			Msg: "Request received!", | ||||
| 		} | ||||
|  | ||||
| 		if err := stream.Send(&rsp); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (d *Debug) Stats(ctx context.Context, req *proto.StatsRequest, rsp *proto.StatsResponse) error { | ||||
| 	stats, err := d.stats.Read() | ||||
| 	if err != nil { | ||||
| @@ -92,11 +114,7 @@ func (d *Debug) Trace(ctx context.Context, req *proto.TraceRequest, rsp *proto.T | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (d *Debug) Log(ctx context.Context, stream server.Stream) error { | ||||
| 	req := new(proto.LogRequest) | ||||
| 	if err := stream.Recv(req); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| func (d *Debug) Log(ctx context.Context, req *proto.LogRequest, stream proto.Debug_LogStream) error { | ||||
|  | ||||
| 	var options []log.ReadOption | ||||
|  | ||||
|   | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @@ -5,7 +5,7 @@ package debug | ||||
|  | ||||
| import ( | ||||
| 	fmt "fmt" | ||||
| 	proto "github.com/golang/protobuf/proto" | ||||
| 	proto "google.golang.org/protobuf/proto" | ||||
| 	math "math" | ||||
| ) | ||||
|  | ||||
| @@ -21,12 +21,6 @@ var _ = proto.Marshal | ||||
| var _ = fmt.Errorf | ||||
| var _ = math.Inf | ||||
|  | ||||
| // This is a compile-time assertion to ensure that this generated file | ||||
| // is compatible with the proto package it is being compiled against. | ||||
| // A compilation error at this line likely means your copy of the | ||||
| // proto package needs to be updated. | ||||
| const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package | ||||
|  | ||||
| // Reference imports to suppress errors if they are not otherwise used. | ||||
| var _ api.Endpoint | ||||
| var _ context.Context | ||||
| @@ -46,6 +40,7 @@ type DebugService interface { | ||||
| 	Health(ctx context.Context, in *HealthRequest, opts ...client.CallOption) (*HealthResponse, error) | ||||
| 	Stats(ctx context.Context, in *StatsRequest, opts ...client.CallOption) (*StatsResponse, error) | ||||
| 	Trace(ctx context.Context, in *TraceRequest, opts ...client.CallOption) (*TraceResponse, error) | ||||
| 	MessageBus(ctx context.Context, opts ...client.CallOption) (Debug_MessageBusService, error) | ||||
| } | ||||
|  | ||||
| type debugService struct { | ||||
| @@ -76,6 +71,7 @@ type Debug_LogService interface { | ||||
| 	Context() context.Context | ||||
| 	SendMsg(interface{}) error | ||||
| 	RecvMsg(interface{}) error | ||||
| 	CloseSend() error | ||||
| 	Close() error | ||||
| 	Recv() (*Record, error) | ||||
| } | ||||
| @@ -84,6 +80,10 @@ type debugServiceLog struct { | ||||
| 	stream client.Stream | ||||
| } | ||||
|  | ||||
| func (x *debugServiceLog) CloseSend() error { | ||||
| 	return x.stream.CloseSend() | ||||
| } | ||||
|  | ||||
| func (x *debugServiceLog) Close() error { | ||||
| 	return x.stream.Close() | ||||
| } | ||||
| @@ -139,6 +139,62 @@ func (c *debugService) Trace(ctx context.Context, in *TraceRequest, opts ...clie | ||||
| 	return out, nil | ||||
| } | ||||
|  | ||||
| func (c *debugService) MessageBus(ctx context.Context, opts ...client.CallOption) (Debug_MessageBusService, error) { | ||||
| 	req := c.c.NewRequest(c.name, "Debug.MessageBus", &BusMsg{}) | ||||
| 	stream, err := c.c.Stream(ctx, req, opts...) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &debugServiceMessageBus{stream}, nil | ||||
| } | ||||
|  | ||||
| type Debug_MessageBusService interface { | ||||
| 	Context() context.Context | ||||
| 	SendMsg(interface{}) error | ||||
| 	RecvMsg(interface{}) error | ||||
| 	CloseSend() error | ||||
| 	Close() error | ||||
| 	Send(*BusMsg) error | ||||
| 	Recv() (*BusMsg, error) | ||||
| } | ||||
|  | ||||
| type debugServiceMessageBus struct { | ||||
| 	stream client.Stream | ||||
| } | ||||
|  | ||||
| func (x *debugServiceMessageBus) CloseSend() error { | ||||
| 	return x.stream.CloseSend() | ||||
| } | ||||
|  | ||||
| func (x *debugServiceMessageBus) Close() error { | ||||
| 	return x.stream.Close() | ||||
| } | ||||
|  | ||||
| func (x *debugServiceMessageBus) Context() context.Context { | ||||
| 	return x.stream.Context() | ||||
| } | ||||
|  | ||||
| func (x *debugServiceMessageBus) SendMsg(m interface{}) error { | ||||
| 	return x.stream.Send(m) | ||||
| } | ||||
|  | ||||
| func (x *debugServiceMessageBus) RecvMsg(m interface{}) error { | ||||
| 	return x.stream.Recv(m) | ||||
| } | ||||
|  | ||||
| func (x *debugServiceMessageBus) Send(m *BusMsg) error { | ||||
| 	return x.stream.Send(m) | ||||
| } | ||||
|  | ||||
| func (x *debugServiceMessageBus) Recv() (*BusMsg, error) { | ||||
| 	m := new(BusMsg) | ||||
| 	err := x.stream.Recv(m) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return m, nil | ||||
| } | ||||
|  | ||||
| // Server API for Debug service | ||||
|  | ||||
| type DebugHandler interface { | ||||
| @@ -146,6 +202,7 @@ type DebugHandler interface { | ||||
| 	Health(context.Context, *HealthRequest, *HealthResponse) error | ||||
| 	Stats(context.Context, *StatsRequest, *StatsResponse) error | ||||
| 	Trace(context.Context, *TraceRequest, *TraceResponse) error | ||||
| 	MessageBus(context.Context, Debug_MessageBusStream) error | ||||
| } | ||||
|  | ||||
| func RegisterDebugHandler(s server.Server, hdlr DebugHandler, opts ...server.HandlerOption) error { | ||||
| @@ -154,6 +211,7 @@ func RegisterDebugHandler(s server.Server, hdlr DebugHandler, opts ...server.Han | ||||
| 		Health(ctx context.Context, in *HealthRequest, out *HealthResponse) error | ||||
| 		Stats(ctx context.Context, in *StatsRequest, out *StatsResponse) error | ||||
| 		Trace(ctx context.Context, in *TraceRequest, out *TraceResponse) error | ||||
| 		MessageBus(ctx context.Context, stream server.Stream) error | ||||
| 	} | ||||
| 	type Debug struct { | ||||
| 		debug | ||||
| @@ -217,3 +275,48 @@ func (h *debugHandler) Stats(ctx context.Context, in *StatsRequest, out *StatsRe | ||||
| func (h *debugHandler) Trace(ctx context.Context, in *TraceRequest, out *TraceResponse) error { | ||||
| 	return h.DebugHandler.Trace(ctx, in, out) | ||||
| } | ||||
|  | ||||
| func (h *debugHandler) MessageBus(ctx context.Context, stream server.Stream) error { | ||||
| 	return h.DebugHandler.MessageBus(ctx, &debugMessageBusStream{stream}) | ||||
| } | ||||
|  | ||||
| type Debug_MessageBusStream interface { | ||||
| 	Context() context.Context | ||||
| 	SendMsg(interface{}) error | ||||
| 	RecvMsg(interface{}) error | ||||
| 	Close() error | ||||
| 	Send(*BusMsg) error | ||||
| 	Recv() (*BusMsg, error) | ||||
| } | ||||
|  | ||||
| type debugMessageBusStream struct { | ||||
| 	stream server.Stream | ||||
| } | ||||
|  | ||||
| func (x *debugMessageBusStream) Close() error { | ||||
| 	return x.stream.Close() | ||||
| } | ||||
|  | ||||
| func (x *debugMessageBusStream) Context() context.Context { | ||||
| 	return x.stream.Context() | ||||
| } | ||||
|  | ||||
| func (x *debugMessageBusStream) SendMsg(m interface{}) error { | ||||
| 	return x.stream.Send(m) | ||||
| } | ||||
|  | ||||
| func (x *debugMessageBusStream) RecvMsg(m interface{}) error { | ||||
| 	return x.stream.Recv(m) | ||||
| } | ||||
|  | ||||
| func (x *debugMessageBusStream) Send(m *BusMsg) error { | ||||
| 	return x.stream.Send(m) | ||||
| } | ||||
|  | ||||
| func (x *debugMessageBusStream) Recv() (*BusMsg, error) { | ||||
| 	m := new(BusMsg) | ||||
| 	if err := x.stream.Recv(m); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return m, nil | ||||
| } | ||||
|   | ||||
| @@ -1,99 +1,107 @@ | ||||
| syntax = "proto3"; | ||||
|  | ||||
| package debug; | ||||
|  | ||||
| option go_package = "./proto;debug"; | ||||
|  | ||||
| // Compile this proto by running the following command in the debug directory: | ||||
| // protoc --proto_path=. --micro_out=. --go_out=:. proto/debug.proto | ||||
|  | ||||
| service Debug { | ||||
| 	rpc Log(LogRequest) returns (stream Record) {}; | ||||
| 	rpc Health(HealthRequest) returns (HealthResponse) {}; | ||||
| 	rpc Stats(StatsRequest) returns (StatsResponse) {}; | ||||
| 	rpc Trace(TraceRequest) returns (TraceResponse) {}; | ||||
|   rpc Log(LogRequest) returns (stream Record) {}; | ||||
|   rpc Health(HealthRequest) returns (HealthResponse) {}; | ||||
|   rpc Stats(StatsRequest) returns (StatsResponse) {}; | ||||
|   rpc Trace(TraceRequest) returns (TraceResponse) {}; | ||||
|   rpc MessageBus(stream BusMsg) returns (stream BusMsg) {}; | ||||
| } | ||||
|  | ||||
| message BusMsg { string msg = 1; } | ||||
|  | ||||
| message HealthRequest { | ||||
| 	// optional service name | ||||
| 	string service = 1; | ||||
|   // optional service name | ||||
|   string service = 1; | ||||
| } | ||||
|  | ||||
| message HealthResponse { | ||||
| 	// default: ok | ||||
| 	string status = 1; | ||||
|   // default: ok | ||||
|   string status = 1; | ||||
| } | ||||
|  | ||||
| message StatsRequest { | ||||
| 	// optional service name | ||||
| 	string service = 1; | ||||
|   // optional service name | ||||
|   string service = 1; | ||||
| } | ||||
|  | ||||
| message StatsResponse { | ||||
| 	// timestamp of recording | ||||
| 	uint64 timestamp = 1; | ||||
| 	// unix timestamp | ||||
| 	uint64 started = 2; | ||||
| 	// in seconds | ||||
| 	uint64 uptime = 3; | ||||
| 	// in bytes | ||||
| 	uint64 memory = 4; | ||||
| 	// num threads | ||||
| 	uint64 threads = 5; | ||||
| 	// total gc in nanoseconds | ||||
| 	uint64 gc = 6; | ||||
| 	// total number of requests | ||||
| 	uint64 requests = 7; | ||||
| 	// total number of errors | ||||
| 	uint64 errors = 8; | ||||
|   // timestamp of recording | ||||
|   uint64 timestamp = 1; | ||||
|   // unix timestamp | ||||
|   uint64 started = 2; | ||||
|   // in seconds | ||||
|   uint64 uptime = 3; | ||||
|   // in bytes | ||||
|   uint64 memory = 4; | ||||
|   // num threads | ||||
|   uint64 threads = 5; | ||||
|   // total gc in nanoseconds | ||||
|   uint64 gc = 6; | ||||
|   // total number of requests | ||||
|   uint64 requests = 7; | ||||
|   // total number of errors | ||||
|   uint64 errors = 8; | ||||
| } | ||||
|  | ||||
| // LogRequest requests service logs | ||||
| message LogRequest { | ||||
| 	// service to request logs for | ||||
| 	string service = 1; | ||||
| 	// stream records continuously | ||||
| 	bool stream = 2; | ||||
| 	// count of records to request | ||||
| 	int64 count = 3; | ||||
| 	// relative time in seconds | ||||
| 	// before the current time | ||||
| 	// from which to show logs | ||||
| 	int64 since = 4; | ||||
|   // service to request logs for | ||||
|   string service = 1; | ||||
|   // stream records continuously | ||||
|   bool stream = 2; | ||||
|   // count of records to request | ||||
|   int64 count = 3; | ||||
|   // relative time in seconds | ||||
|   // before the current time | ||||
|   // from which to show logs | ||||
|   int64 since = 4; | ||||
| } | ||||
|  | ||||
| // Record is service log record | ||||
| // Also used as default basic message type to test requests. | ||||
| message Record { | ||||
|     // timestamp of log record | ||||
|     int64 timestamp = 1; | ||||
|     // record metadata | ||||
|     map<string,string> metadata = 2; | ||||
|     // message | ||||
|     string message = 3; | ||||
|   // timestamp of log record | ||||
|   int64 timestamp = 1; | ||||
|   // record metadata | ||||
|   map<string, string> metadata = 2; | ||||
|   // message | ||||
|   string message = 3; | ||||
| } | ||||
|  | ||||
| message TraceRequest { | ||||
| 	// trace id to retrieve | ||||
| 	string id = 1; | ||||
| } | ||||
|  | ||||
| message TraceResponse { | ||||
| 	repeated Span spans = 1; | ||||
|   // trace id to retrieve | ||||
|   string id = 1; | ||||
| } | ||||
|  | ||||
| message TraceResponse { repeated Span spans = 1; } | ||||
|  | ||||
| enum SpanType { | ||||
|     INBOUND = 0; | ||||
|     OUTBOUND = 1; | ||||
|   INBOUND = 0; | ||||
|   OUTBOUND = 1; | ||||
| } | ||||
|  | ||||
| message Span { | ||||
| 	// the trace id | ||||
| 	string trace = 1; | ||||
| 	// id of the span | ||||
| 	string id = 2; | ||||
| 	// parent span | ||||
| 	string parent = 3; | ||||
| 	// name of the resource | ||||
| 	string name = 4; | ||||
| 	// time of start in nanoseconds | ||||
| 	uint64 started = 5; | ||||
| 	// duration of the execution in nanoseconds | ||||
| 	uint64 duration = 6; | ||||
| 	// associated metadata | ||||
| 	map<string,string> metadata = 7; | ||||
| 	SpanType type = 8; | ||||
|   // the trace id | ||||
|   string trace = 1; | ||||
|   // id of the span | ||||
|   string id = 2; | ||||
|   // parent span | ||||
|   string parent = 3; | ||||
|   // name of the resource | ||||
|   string name = 4; | ||||
|   // time of start in nanoseconds | ||||
|   uint64 started = 5; | ||||
|   // duration of the execution in nanoseconds | ||||
|   uint64 duration = 6; | ||||
|   // associated metadata | ||||
|   map<string, string> metadata = 7; | ||||
|   SpanType type = 8; | ||||
| } | ||||
|   | ||||
							
								
								
									
										21
									
								
								debug/trace/noop.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								debug/trace/noop.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | ||||
| package trace | ||||
|  | ||||
| import "context" | ||||
|  | ||||
| type noop struct{} | ||||
|  | ||||
| func (n *noop) Init(...Option) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (n *noop) Start(ctx context.Context, name string) (context.Context, *Span) { | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| func (n *noop) Finish(*Span) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (n *noop) Read(...ReadOption) ([]*Span, error) { | ||||
| 	return nil, nil | ||||
| } | ||||
| @@ -6,6 +6,12 @@ import ( | ||||
| 	"time" | ||||
|  | ||||
| 	"go-micro.dev/v4/metadata" | ||||
| 	"go-micro.dev/v4/transport/headers" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	// DefaultTracer is the default tracer. | ||||
| 	DefaultTracer = NewTracer() | ||||
| ) | ||||
|  | ||||
| // Tracer is an interface for distributed tracing. | ||||
| @@ -48,52 +54,29 @@ type Span struct { | ||||
| 	Type SpanType | ||||
| } | ||||
|  | ||||
| const ( | ||||
| 	traceIDKey = "Micro-Trace-Id" | ||||
| 	spanIDKey  = "Micro-Span-Id" | ||||
| ) | ||||
|  | ||||
| // FromContext returns a span from context. | ||||
| func FromContext(ctx context.Context) (traceID string, parentSpanID string, isFound bool) { | ||||
| 	traceID, traceOk := metadata.Get(ctx, traceIDKey) | ||||
| 	microID, microOk := metadata.Get(ctx, "Micro-Id") | ||||
| 	traceID, traceOk := metadata.Get(ctx, headers.TraceIDKey) | ||||
| 	microID, microOk := metadata.Get(ctx, headers.ID) | ||||
|  | ||||
| 	if !traceOk && !microOk { | ||||
| 		isFound = false | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if !traceOk { | ||||
| 		traceID = microID | ||||
| 	} | ||||
| 	parentSpanID, ok := metadata.Get(ctx, spanIDKey) | ||||
|  | ||||
| 	parentSpanID, ok := metadata.Get(ctx, headers.SpanID) | ||||
|  | ||||
| 	return traceID, parentSpanID, ok | ||||
| } | ||||
|  | ||||
| // ToContext saves the trace and span ids in the context. | ||||
| func ToContext(ctx context.Context, traceID, parentSpanID string) context.Context { | ||||
| 	return metadata.MergeContext(ctx, map[string]string{ | ||||
| 		traceIDKey: traceID, | ||||
| 		spanIDKey:  parentSpanID, | ||||
| 		headers.TraceIDKey: traceID, | ||||
| 		headers.SpanID:     parentSpanID, | ||||
| 	}, true) | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	DefaultTracer Tracer = NewTracer() | ||||
| ) | ||||
|  | ||||
| type noop struct{} | ||||
|  | ||||
| func (n *noop) Init(...Option) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (n *noop) Start(ctx context.Context, name string) (context.Context, *Span) { | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| func (n *noop) Finish(*Span) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (n *noop) Read(...ReadOption) ([]*Span, error) { | ||||
| 	return nil, nil | ||||
| } | ||||
|   | ||||
							
								
								
									
										4
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								go.mod
									
									
									
									
									
								
							| @@ -1,6 +1,6 @@ | ||||
| module go-micro.dev/v4 | ||||
|  | ||||
| go 1.17 | ||||
| go 1.18 | ||||
|  | ||||
| require ( | ||||
| 	github.com/bitly/go-simplejson v0.5.0 | ||||
| @@ -69,7 +69,7 @@ require ( | ||||
| 	github.com/sirupsen/logrus v1.7.0 // indirect | ||||
| 	github.com/xanzy/ssh-agent v0.3.0 // indirect | ||||
| 	go.opencensus.io v0.22.3 // indirect | ||||
| 	golang.org/x/sys v0.0.0-20210502180810-71e4cd670f79 // indirect | ||||
| 	golang.org/x/sys v0.0.0-20210510120138-977fb7262007 // indirect | ||||
| 	golang.org/x/text v0.3.6 // indirect | ||||
| 	gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect | ||||
| 	gopkg.in/warnings.v0 v0.1.2 // indirect | ||||
|   | ||||
							
								
								
									
										5
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								go.sum
									
									
									
									
									
								
							| @@ -507,7 +507,6 @@ github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1 | ||||
| github.com/transip/gotransip/v6 v6.2.0/go.mod h1:pQZ36hWWRahCUXkFWlx9Hs711gLd8J4qdgLdRzmtY+g= | ||||
| github.com/uber-go/atomic v1.3.2/go.mod h1:/Ct5t2lcmbJ4OSe/waGBoaVvVqtO0bmtfVNex1PFV8g= | ||||
| github.com/urfave/cli v1.22.2/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= | ||||
| github.com/urfave/cli v1.22.4 h1:u7tSpNPPswAFymm8IehJhy4uJMlUuU/GmqSkvJ1InXA= | ||||
| github.com/urfave/cli v1.22.4/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= | ||||
| github.com/urfave/cli/v2 v2.3.0 h1:qph92Y649prgesehzOrQjdWyxFOp/QVM+6imKHad91M= | ||||
| github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI= | ||||
| @@ -689,9 +688,9 @@ golang.org/x/sys v0.0.0-20210216224549-f992740a1bac/go.mod h1:h1NjWce9XRLGQEsW7w | ||||
| golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20210324051608-47abb6519492/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20210502180810-71e4cd670f79 h1:RX8C8PRZc2hTIod4ds8ij+/4RQX3AqhYj3uOHmyaz4E= | ||||
| golang.org/x/sys v0.0.0-20210502180810-71e4cd670f79/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20210510120138-977fb7262007 h1:gG67DSER+11cZvqIMb8S8bt0vZtiN6xWYARwirrOSfE= | ||||
| golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/term v0.0.0-20201113234701-d7a72108b828/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= | ||||
| golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= | ||||
| golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E= | ||||
|   | ||||
| @@ -32,6 +32,7 @@ func (l *defaultLogger) Init(opts ...Option) error { | ||||
| 	for _, o := range opts { | ||||
| 		o(&l.opts) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| @@ -42,6 +43,7 @@ func (l *defaultLogger) String() string { | ||||
| func (l *defaultLogger) Fields(fields map[string]interface{}) Logger { | ||||
| 	l.Lock() | ||||
| 	nfields := make(map[string]interface{}, len(l.opts.Fields)) | ||||
|  | ||||
| 	for k, v := range l.opts.Fields { | ||||
| 		nfields[k] = v | ||||
| 	} | ||||
| @@ -65,6 +67,7 @@ func copyFields(src map[string]interface{}) map[string]interface{} { | ||||
| 	for k, v := range src { | ||||
| 		dst[k] = v | ||||
| 	} | ||||
|  | ||||
| 	return dst | ||||
| } | ||||
|  | ||||
| @@ -85,10 +88,13 @@ func logCallerfilePath(loggingFilePath string) string { | ||||
| 	if idx == -1 { | ||||
| 		return loggingFilePath | ||||
| 	} | ||||
|  | ||||
| 	idx = strings.LastIndexByte(loggingFilePath[:idx], '/') | ||||
|  | ||||
| 	if idx == -1 { | ||||
| 		return loggingFilePath | ||||
| 	} | ||||
|  | ||||
| 	return loggingFilePath[idx+1:] | ||||
| } | ||||
|  | ||||
| @@ -121,6 +127,7 @@ func (l *defaultLogger) Log(level Level, v ...interface{}) { | ||||
| 	} | ||||
|  | ||||
| 	sort.Strings(keys) | ||||
|  | ||||
| 	metadata := "" | ||||
|  | ||||
| 	for _, k := range keys { | ||||
| @@ -162,6 +169,7 @@ func (l *defaultLogger) Logf(level Level, format string, v ...interface{}) { | ||||
| 	} | ||||
|  | ||||
| 	sort.Strings(keys) | ||||
|  | ||||
| 	metadata := "" | ||||
|  | ||||
| 	for _, k := range keys { | ||||
| @@ -177,9 +185,11 @@ func (l *defaultLogger) Logf(level Level, format string, v ...interface{}) { | ||||
| func (l *defaultLogger) Options() Options { | ||||
| 	// not guard against options Context values | ||||
| 	l.RLock() | ||||
| 	defer l.RUnlock() | ||||
|  | ||||
| 	opts := l.opts | ||||
| 	opts.Fields = copyFields(l.opts.Fields) | ||||
| 	l.RUnlock() | ||||
|  | ||||
| 	return opts | ||||
| } | ||||
|  | ||||
|   | ||||
							
								
								
									
										1
									
								
								micro.go
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								micro.go
									
									
									
									
									
								
							| @@ -62,6 +62,7 @@ func NewEvent(topic string, c client.Client) Event { | ||||
| 	if c == nil { | ||||
| 		c = client.NewClient() | ||||
| 	} | ||||
|  | ||||
| 	return &event{c, topic} | ||||
| } | ||||
|  | ||||
|   | ||||
							
								
								
									
										4
									
								
								registry/cache/cache.go
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								registry/cache/cache.go
									
									
									
									
										vendored
									
									
								
							| @@ -7,10 +7,11 @@ import ( | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"golang.org/x/sync/singleflight" | ||||
|  | ||||
| 	log "go-micro.dev/v4/logger" | ||||
| 	"go-micro.dev/v4/registry" | ||||
| 	util "go-micro.dev/v4/util/registry" | ||||
| 	"golang.org/x/sync/singleflight" | ||||
| ) | ||||
|  | ||||
| // Cache is the registry cache interface. | ||||
| @@ -464,6 +465,7 @@ func (c *cache) String() string { | ||||
| // New returns a new cache. | ||||
| func New(r registry.Registry, opts ...Option) Cache { | ||||
| 	rand.Seed(time.Now().UnixNano()) | ||||
|  | ||||
| 	options := Options{ | ||||
| 		TTL:    DefaultTTL, | ||||
| 		Logger: log.DefaultLogger, | ||||
|   | ||||
| @@ -1,8 +1,11 @@ | ||||
| package selector | ||||
|  | ||||
| import ( | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/pkg/errors" | ||||
|  | ||||
| 	"go-micro.dev/v4/registry" | ||||
| 	"go-micro.dev/v4/registry/cache" | ||||
| ) | ||||
| @@ -10,19 +13,25 @@ import ( | ||||
| type registrySelector struct { | ||||
| 	so Options | ||||
| 	rc cache.Cache | ||||
| 	mu sync.RWMutex | ||||
| } | ||||
|  | ||||
| func (c *registrySelector) newCache() cache.Cache { | ||||
| 	opts := make([]cache.Option, 0, 1) | ||||
|  | ||||
| 	if c.so.Context != nil { | ||||
| 		if t, ok := c.so.Context.Value("selector_ttl").(time.Duration); ok { | ||||
| 			opts = append(opts, cache.WithTTL(t)) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return cache.New(c.so.Registry, opts...) | ||||
| } | ||||
|  | ||||
| func (c *registrySelector) Init(opts ...Option) error { | ||||
| 	c.mu.Lock() | ||||
| 	defer c.mu.Unlock() | ||||
|  | ||||
| 	for _, o := range opts { | ||||
| 		o(&c.so) | ||||
| 	} | ||||
| @@ -38,6 +47,9 @@ func (c *registrySelector) Options() Options { | ||||
| } | ||||
|  | ||||
| func (c *registrySelector) Select(service string, opts ...SelectOption) (Next, error) { | ||||
| 	c.mu.RLock() | ||||
| 	defer c.mu.RUnlock() | ||||
|  | ||||
| 	sopts := SelectOptions{ | ||||
| 		Strategy: c.so.Strategy, | ||||
| 	} | ||||
| @@ -51,9 +63,10 @@ func (c *registrySelector) Select(service string, opts ...SelectOption) (Next, e | ||||
| 	// if that fails go directly to the registry | ||||
| 	services, err := c.rc.GetService(service) | ||||
| 	if err != nil { | ||||
| 		if err == registry.ErrNotFound { | ||||
| 		if errors.Is(err, registry.ErrNotFound) { | ||||
| 			return nil, ErrNotFound | ||||
| 		} | ||||
|  | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| @@ -87,6 +100,7 @@ func (c *registrySelector) String() string { | ||||
| 	return "registry" | ||||
| } | ||||
|  | ||||
| // NewSelector creates a new default selector. | ||||
| func NewSelector(opts ...Option) Selector { | ||||
| 	sopts := Options{ | ||||
| 		Strategy: Random, | ||||
|   | ||||
| @@ -6,12 +6,13 @@ import ( | ||||
| ) | ||||
|  | ||||
| type serverKey struct{} | ||||
| type wgKey struct{} | ||||
|  | ||||
| func wait(ctx context.Context) *sync.WaitGroup { | ||||
| 	if ctx == nil { | ||||
| 		return nil | ||||
| 	} | ||||
| 	wg, ok := ctx.Value("wait").(*sync.WaitGroup) | ||||
| 	wg, ok := ctx.Value(wgKey{}).(*sync.WaitGroup) | ||||
| 	if !ok { | ||||
| 		return nil | ||||
| 	} | ||||
|   | ||||
| @@ -20,7 +20,7 @@ type RouterOptions struct { | ||||
|  | ||||
| type RouterOption func(o *RouterOptions) | ||||
|  | ||||
| func newRouterOptions(opt ...RouterOption) RouterOptions { | ||||
| func NewRouterOptions(opt ...RouterOption) RouterOptions { | ||||
| 	opts := RouterOptions{ | ||||
| 		Logger: logger.DefaultLogger, | ||||
| 	} | ||||
| @@ -74,7 +74,8 @@ type Options struct { | ||||
| 	Context context.Context | ||||
| } | ||||
|  | ||||
| func newOptions(opt ...Option) Options { | ||||
| // NewOptions creates new server options. | ||||
| func NewOptions(opt ...Option) Options { | ||||
| 	opts := Options{ | ||||
| 		Codecs:           make(map[string]codec.NewCodec), | ||||
| 		Metadata:         map[string]string{}, | ||||
| @@ -275,7 +276,7 @@ func Wait(wg *sync.WaitGroup) Option { | ||||
| 		if wg == nil { | ||||
| 			wg = new(sync.WaitGroup) | ||||
| 		} | ||||
| 		o.Context = context.WithValue(o.Context, "wait", wg) | ||||
| 		o.Context = context.WithValue(o.Context, wgKey{}, wg) | ||||
| 	} | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -6,6 +6,7 @@ import ( | ||||
|  | ||||
| 	"github.com/oxtoacart/bpool" | ||||
| 	"github.com/pkg/errors" | ||||
|  | ||||
| 	"go-micro.dev/v4/codec" | ||||
| 	raw "go-micro.dev/v4/codec/bytes" | ||||
| 	"go-micro.dev/v4/codec/grpc" | ||||
| @@ -14,6 +15,7 @@ import ( | ||||
| 	"go-micro.dev/v4/codec/proto" | ||||
| 	"go-micro.dev/v4/codec/protorpc" | ||||
| 	"go-micro.dev/v4/transport" | ||||
| 	"go-micro.dev/v4/transport/headers" | ||||
| ) | ||||
|  | ||||
| type rpcCodec struct { | ||||
| @@ -36,6 +38,7 @@ type readWriteCloser struct { | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	// DefaultContentType is the default codec content type. | ||||
| 	DefaultContentType = "application/protobuf" | ||||
|  | ||||
| 	DefaultCodecs = map[string]codec.NewCodec{ | ||||
| @@ -65,12 +68,14 @@ var ( | ||||
| func (rwc *readWriteCloser) Read(p []byte) (n int, err error) { | ||||
| 	rwc.RLock() | ||||
| 	defer rwc.RUnlock() | ||||
|  | ||||
| 	return rwc.rbuf.Read(p) | ||||
| } | ||||
|  | ||||
| func (rwc *readWriteCloser) Write(p []byte) (n int, err error) { | ||||
| 	rwc.Lock() | ||||
| 	defer rwc.Unlock() | ||||
|  | ||||
| 	return rwc.wbuf.Write(p) | ||||
| } | ||||
|  | ||||
| @@ -82,6 +87,7 @@ func getHeader(hdr string, md map[string]string) string { | ||||
| 	if hd := md[hdr]; len(hd) > 0 { | ||||
| 		return hd | ||||
| 	} | ||||
|  | ||||
| 	return md["X-"+hdr] | ||||
| } | ||||
|  | ||||
| @@ -90,14 +96,15 @@ func getHeaders(m *codec.Message) { | ||||
| 		if len(v) > 0 { | ||||
| 			return v | ||||
| 		} | ||||
|  | ||||
| 		return m.Header[hdr] | ||||
| 	} | ||||
|  | ||||
| 	m.Id = set(m.Id, "Micro-Id") | ||||
| 	m.Error = set(m.Error, "Micro-Error") | ||||
| 	m.Endpoint = set(m.Endpoint, "Micro-Endpoint") | ||||
| 	m.Method = set(m.Method, "Micro-Method") | ||||
| 	m.Target = set(m.Target, "Micro-Service") | ||||
| 	m.Id = set(m.Id, headers.ID) | ||||
| 	m.Error = set(m.Error, headers.Error) | ||||
| 	m.Endpoint = set(m.Endpoint, headers.Endpoint) | ||||
| 	m.Method = set(m.Method, headers.Method) | ||||
| 	m.Target = set(m.Target, headers.Request) | ||||
|  | ||||
| 	// TODO: remove this cruft | ||||
| 	if len(m.Endpoint) == 0 { | ||||
| @@ -110,26 +117,27 @@ func setHeaders(m, r *codec.Message) { | ||||
| 		if len(v) == 0 { | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		m.Header[hdr] = v | ||||
| 		m.Header["X-"+hdr] = v | ||||
| 	} | ||||
|  | ||||
| 	// set headers | ||||
| 	set("Micro-Id", r.Id) | ||||
| 	set("Micro-Service", r.Target) | ||||
| 	set("Micro-Method", r.Method) | ||||
| 	set("Micro-Endpoint", r.Endpoint) | ||||
| 	set("Micro-Error", r.Error) | ||||
| 	set(headers.ID, r.Id) | ||||
| 	set(headers.Request, r.Target) | ||||
| 	set(headers.Method, r.Method) | ||||
| 	set(headers.Endpoint, r.Endpoint) | ||||
| 	set(headers.Error, r.Error) | ||||
| } | ||||
|  | ||||
| // setupProtocol sets up the old protocol. | ||||
| func setupProtocol(msg *transport.Message) codec.NewCodec { | ||||
| 	service := getHeader("Micro-Service", msg.Header) | ||||
| 	method := getHeader("Micro-Method", msg.Header) | ||||
| 	endpoint := getHeader("Micro-Endpoint", msg.Header) | ||||
| 	protocol := getHeader("Micro-Protocol", msg.Header) | ||||
| 	target := getHeader("Micro-Target", msg.Header) | ||||
| 	topic := getHeader("Micro-Topic", msg.Header) | ||||
| 	service := getHeader(headers.Request, msg.Header) | ||||
| 	method := getHeader(headers.Method, msg.Header) | ||||
| 	endpoint := getHeader(headers.Endpoint, msg.Header) | ||||
| 	protocol := getHeader(headers.Protocol, msg.Header) | ||||
| 	target := getHeader(headers.Target, msg.Header) | ||||
| 	topic := getHeader(headers.Message, msg.Header) | ||||
|  | ||||
| 	// if the protocol exists (mucp) do nothing | ||||
| 	if len(protocol) > 0 { | ||||
| @@ -153,18 +161,18 @@ func setupProtocol(msg *transport.Message) codec.NewCodec { | ||||
|  | ||||
| 	// no method then set to endpoint | ||||
| 	if len(method) == 0 { | ||||
| 		msg.Header["Micro-Method"] = endpoint | ||||
| 		msg.Header[headers.Method] = endpoint | ||||
| 	} | ||||
|  | ||||
| 	// no endpoint then set to method | ||||
| 	if len(endpoint) == 0 { | ||||
| 		msg.Header["Micro-Endpoint"] = method | ||||
| 		msg.Header[headers.Endpoint] = method | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func newRpcCodec(req *transport.Message, socket transport.Socket, c codec.NewCodec) codec.Codec { | ||||
| func newRPCCodec(req *transport.Message, socket transport.Socket, c codec.NewCodec) codec.Codec { | ||||
| 	rwc := &readWriteCloser{ | ||||
| 		rbuf: bufferPool.Get(), | ||||
| 		wbuf: bufferPool.Get(), | ||||
| @@ -185,7 +193,6 @@ func newRpcCodec(req *transport.Message, socket transport.Socket, c codec.NewCod | ||||
| 	case "grpc": | ||||
| 		// write the body | ||||
| 		rwc.rbuf.Write(req.Body) | ||||
| 		// set the protocol | ||||
| 		r.protocol = "grpc" | ||||
| 	default: | ||||
| 		// first is not preloaded | ||||
| @@ -197,7 +204,7 @@ func newRpcCodec(req *transport.Message, socket transport.Socket, c codec.NewCod | ||||
|  | ||||
| func (c *rpcCodec) ReadHeader(r *codec.Message, t codec.MessageType) error { | ||||
| 	// the initial message | ||||
| 	m := codec.Message{ | ||||
| 	mmsg := codec.Message{ | ||||
| 		Header: c.req.Header, | ||||
| 		Body:   c.req.Body, | ||||
| 	} | ||||
| @@ -221,9 +228,9 @@ func (c *rpcCodec) ReadHeader(r *codec.Message, t codec.MessageType) error { | ||||
| 		} | ||||
|  | ||||
| 		// set the message header | ||||
| 		m.Header = tm.Header | ||||
| 		mmsg.Header = tm.Header | ||||
| 		// set the message body | ||||
| 		m.Body = tm.Body | ||||
| 		mmsg.Body = tm.Body | ||||
|  | ||||
| 		// set req | ||||
| 		c.req = &tm | ||||
| @@ -248,20 +255,20 @@ func (c *rpcCodec) ReadHeader(r *codec.Message, t codec.MessageType) error { | ||||
| 	} | ||||
|  | ||||
| 	// set some internal things | ||||
| 	getHeaders(&m) | ||||
| 	getHeaders(&mmsg) | ||||
|  | ||||
| 	// read header via codec | ||||
| 	if err := c.codec.ReadHeader(&m, codec.Request); err != nil { | ||||
| 	if err := c.codec.ReadHeader(&mmsg, codec.Request); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// fallback for 0.14 and older | ||||
| 	if len(m.Endpoint) == 0 { | ||||
| 		m.Endpoint = m.Method | ||||
| 	if len(mmsg.Endpoint) == 0 { | ||||
| 		mmsg.Endpoint = mmsg.Method | ||||
| 	} | ||||
|  | ||||
| 	// set message | ||||
| 	*r = m | ||||
| 	*r = mmsg | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
| @@ -315,7 +322,7 @@ func (c *rpcCodec) Write(r *codec.Message, b interface{}) error { | ||||
|  | ||||
| 		// write an error if it failed | ||||
| 		m.Error = errors.Wrapf(err, "Unable to encode body").Error() | ||||
| 		m.Header["Micro-Error"] = m.Error | ||||
| 		m.Header[headers.Error] = m.Error | ||||
| 		// no body to write | ||||
| 		if err := c.codec.Write(m, nil); err != nil { | ||||
| 			return err | ||||
|   | ||||
| @@ -3,6 +3,7 @@ package server | ||||
| import ( | ||||
| 	"go-micro.dev/v4/broker" | ||||
| 	"go-micro.dev/v4/transport" | ||||
| 	"go-micro.dev/v4/transport/headers" | ||||
| ) | ||||
|  | ||||
| // event is a broker event we handle on the server transport. | ||||
| @@ -25,7 +26,7 @@ func (e *event) Error() error { | ||||
| } | ||||
|  | ||||
| func (e *event) Topic() string { | ||||
| 	return e.message.Header["Micro-Topic"] | ||||
| 	return e.message.Header[headers.Message] | ||||
| } | ||||
|  | ||||
| func newEvent(msg transport.Message) *event { | ||||
|   | ||||
							
								
								
									
										143
									
								
								server/rpc_events.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										143
									
								
								server/rpc_events.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,143 @@ | ||||
| package server | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
|  | ||||
| 	"go-micro.dev/v4/broker" | ||||
| 	raw "go-micro.dev/v4/codec/bytes" | ||||
| 	log "go-micro.dev/v4/logger" | ||||
| 	"go-micro.dev/v4/metadata" | ||||
| 	"go-micro.dev/v4/transport/headers" | ||||
| ) | ||||
|  | ||||
| // HandleEvent handles inbound messages to the service directly. | ||||
| // These events are a result of registering to the topic with the service name. | ||||
| // TODO: handle requests from an event. We won't send a response. | ||||
| func (s *rpcServer) HandleEvent(e broker.Event) error { | ||||
| 	// formatting horrible cruft | ||||
| 	msg := e.Message() | ||||
|  | ||||
| 	if msg.Header == nil { | ||||
| 		msg.Header = make(map[string]string) | ||||
| 	} | ||||
|  | ||||
| 	contentType, ok := msg.Header["Content-Type"] | ||||
| 	if !ok || len(contentType) == 0 { | ||||
| 		msg.Header["Content-Type"] = DefaultContentType | ||||
| 		contentType = DefaultContentType | ||||
| 	} | ||||
|  | ||||
| 	cf, err := s.newCodec(contentType) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	header := make(map[string]string, len(msg.Header)) | ||||
| 	for k, v := range msg.Header { | ||||
| 		header[k] = v | ||||
| 	} | ||||
|  | ||||
| 	// create context | ||||
| 	ctx := metadata.NewContext(context.Background(), header) | ||||
|  | ||||
| 	// TODO: inspect message header for Micro-Service & Micro-Topic | ||||
| 	rpcMsg := &rpcMessage{ | ||||
| 		topic:       msg.Header[headers.Message], | ||||
| 		contentType: contentType, | ||||
| 		payload:     &raw.Frame{Data: msg.Body}, | ||||
| 		codec:       cf, | ||||
| 		header:      msg.Header, | ||||
| 		body:        msg.Body, | ||||
| 	} | ||||
|  | ||||
| 	// if the router is present then execute it | ||||
| 	r := Router(s.router) | ||||
| 	if s.opts.Router != nil { | ||||
| 		// create a wrapped function | ||||
| 		handler := s.opts.Router.ProcessMessage | ||||
|  | ||||
| 		// execute the wrapper for it | ||||
| 		for i := len(s.opts.SubWrappers); i > 0; i-- { | ||||
| 			handler = s.opts.SubWrappers[i-1](handler) | ||||
| 		} | ||||
|  | ||||
| 		// set the router | ||||
| 		r = rpcRouter{m: handler} | ||||
| 	} | ||||
|  | ||||
| 	return r.ProcessMessage(ctx, rpcMsg) | ||||
| } | ||||
|  | ||||
| func (s *rpcServer) NewSubscriber(topic string, sb interface{}, opts ...SubscriberOption) Subscriber { | ||||
| 	return s.router.NewSubscriber(topic, sb, opts...) | ||||
| } | ||||
|  | ||||
| func (s *rpcServer) Subscribe(sb Subscriber) error { | ||||
| 	s.Lock() | ||||
| 	defer s.Unlock() | ||||
|  | ||||
| 	sub, ok := sb.(*subscriber) | ||||
| 	if !ok { | ||||
| 		return fmt.Errorf("invalid subscriber: expected *subscriber") | ||||
| 	} | ||||
| 	if len(sub.handlers) == 0 { | ||||
| 		return fmt.Errorf("invalid subscriber: no handler functions") | ||||
| 	} | ||||
|  | ||||
| 	if err := validateSubscriber(sub); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// append to subscribers | ||||
| 	// subs := s.subscribers[sub.Topic()] | ||||
| 	// subs = append(subs, sub) | ||||
| 	// router.subscribers[sub.Topic()] = subs | ||||
|  | ||||
| 	s.subscribers[sb] = nil | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // subscribeServer will subscribe the server to the topic with its own name. | ||||
| func (s *rpcServer) subscribeServer(config Options) error { | ||||
| 	if s.opts.Router != nil { | ||||
| 		sub, err := s.opts.Broker.Subscribe(config.Name, s.HandleEvent) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		// Save the subscriber | ||||
| 		s.subscriber = sub | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // reSubscribe itterates over subscribers and re-subscribes then. | ||||
| func (s *rpcServer) reSubscribe(config Options) error { | ||||
| 	for sb := range s.subscribers { | ||||
| 		var opts []broker.SubscribeOption | ||||
| 		if queue := sb.Options().Queue; len(queue) > 0 { | ||||
| 			opts = append(opts, broker.Queue(queue)) | ||||
| 		} | ||||
|  | ||||
| 		if ctx := sb.Options().Context; ctx != nil { | ||||
| 			opts = append(opts, broker.SubscribeContext(ctx)) | ||||
| 		} | ||||
|  | ||||
| 		if !sb.Options().AutoAck { | ||||
| 			opts = append(opts, broker.DisableAutoAck()) | ||||
| 		} | ||||
|  | ||||
| 		config.Logger.Logf(log.InfoLevel, "Subscribing to topic: %s", sb.Topic()) | ||||
| 		sub, err := config.Broker.Subscribe(sb.Topic(), s.HandleEvent, opts...) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		s.subscribers[sb] = []broker.Subscriber{sub} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
| @@ -6,14 +6,14 @@ import ( | ||||
| 	"go-micro.dev/v4/registry" | ||||
| ) | ||||
|  | ||||
| type rpcHandler struct { | ||||
| type RpcHandler struct { | ||||
| 	name      string | ||||
| 	handler   interface{} | ||||
| 	endpoints []*registry.Endpoint | ||||
| 	opts      HandlerOptions | ||||
| } | ||||
|  | ||||
| func newRpcHandler(handler interface{}, opts ...HandlerOption) Handler { | ||||
| func NewRpcHandler(handler interface{}, opts ...HandlerOption) Handler { | ||||
| 	options := HandlerOptions{ | ||||
| 		Metadata: make(map[string]map[string]string), | ||||
| 	} | ||||
| @@ -40,7 +40,7 @@ func newRpcHandler(handler interface{}, opts ...HandlerOption) Handler { | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return &rpcHandler{ | ||||
| 	return &RpcHandler{ | ||||
| 		name:      name, | ||||
| 		handler:   handler, | ||||
| 		endpoints: endpoints, | ||||
| @@ -48,18 +48,18 @@ func newRpcHandler(handler interface{}, opts ...HandlerOption) Handler { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (r *rpcHandler) Name() string { | ||||
| func (r *RpcHandler) Name() string { | ||||
| 	return r.name | ||||
| } | ||||
|  | ||||
| func (r *rpcHandler) Handler() interface{} { | ||||
| func (r *RpcHandler) Handler() interface{} { | ||||
| 	return r.handler | ||||
| } | ||||
|  | ||||
| func (r *rpcHandler) Endpoints() []*registry.Endpoint { | ||||
| func (r *RpcHandler) Endpoints() []*registry.Endpoint { | ||||
| 	return r.endpoints | ||||
| } | ||||
|  | ||||
| func (r *rpcHandler) Options() HandlerOptions { | ||||
| func (r *RpcHandler) Options() HandlerOptions { | ||||
| 	return r.opts | ||||
| } | ||||
|   | ||||
							
								
								
									
										101
									
								
								server/rpc_helper.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										101
									
								
								server/rpc_helper.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,101 @@ | ||||
| package server | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"sync" | ||||
|  | ||||
| 	"go-micro.dev/v4/codec" | ||||
| 	"go-micro.dev/v4/registry" | ||||
| ) | ||||
|  | ||||
| // setRegistered will set the service as registered safely. | ||||
| func (s *rpcServer) setRegistered(b bool) { | ||||
| 	s.Lock() | ||||
| 	defer s.Unlock() | ||||
|  | ||||
| 	s.registered = b | ||||
| } | ||||
|  | ||||
| // isRegistered will check if the service has already been registered. | ||||
| func (s *rpcServer) isRegistered() bool { | ||||
| 	s.RLock() | ||||
| 	defer s.RUnlock() | ||||
|  | ||||
| 	return s.registered | ||||
| } | ||||
|  | ||||
| // setStarted will set started state safely. | ||||
| func (s *rpcServer) setStarted(b bool) { | ||||
| 	s.Lock() | ||||
| 	defer s.Unlock() | ||||
|  | ||||
| 	s.started = b | ||||
| } | ||||
|  | ||||
| // isStarted will check if the service has already been started. | ||||
| func (s *rpcServer) isStarted() bool { | ||||
| 	s.RLock() | ||||
| 	defer s.RUnlock() | ||||
|  | ||||
| 	return s.started | ||||
| } | ||||
|  | ||||
| // setWg will set the waitgroup safely. | ||||
| func (s *rpcServer) setWg(wg *sync.WaitGroup) { | ||||
| 	s.Lock() | ||||
| 	defer s.Unlock() | ||||
|  | ||||
| 	s.wg = wg | ||||
| } | ||||
|  | ||||
| // getWaitgroup returns the global waitgroup safely. | ||||
| func (s *rpcServer) getWg() *sync.WaitGroup { | ||||
| 	s.RLock() | ||||
| 	defer s.RUnlock() | ||||
|  | ||||
| 	return s.wg | ||||
| } | ||||
|  | ||||
| // setOptsAddr will set the address in the service options safely. | ||||
| func (s *rpcServer) setOptsAddr(addr string) { | ||||
| 	s.Lock() | ||||
| 	defer s.Unlock() | ||||
|  | ||||
| 	s.opts.Address = addr | ||||
| } | ||||
|  | ||||
| func (s *rpcServer) getCachedService() *registry.Service { | ||||
| 	s.RLock() | ||||
| 	defer s.RUnlock() | ||||
|  | ||||
| 	return s.rsvc | ||||
| } | ||||
|  | ||||
| func (s *rpcServer) Options() Options { | ||||
| 	s.RLock() | ||||
| 	defer s.RUnlock() | ||||
|  | ||||
| 	return s.opts | ||||
| } | ||||
|  | ||||
| // swapAddr swaps the address found in the config and the transport address. | ||||
| func (s *rpcServer) swapAddr(config Options, addr string) string { | ||||
| 	s.Lock() | ||||
| 	defer s.Unlock() | ||||
|  | ||||
| 	a := config.Address | ||||
| 	s.opts.Address = addr | ||||
| 	return a | ||||
| } | ||||
|  | ||||
| func (s *rpcServer) newCodec(contentType string) (codec.NewCodec, error) { | ||||
| 	if cf, ok := s.opts.Codecs[contentType]; ok { | ||||
| 		return cf, nil | ||||
| 	} | ||||
|  | ||||
| 	if cf, ok := DefaultCodecs[contentType]; ok { | ||||
| 		return cf, nil | ||||
| 	} | ||||
|  | ||||
| 	return nil, fmt.Errorf("unsupported Content-Type: %s", contentType) | ||||
| } | ||||
| @@ -1,11 +1,5 @@ | ||||
| package server | ||||
|  | ||||
| // Copyright 2009 The Go Authors. All rights reserved. | ||||
| // Use of this source code is governed by a BSD-style | ||||
| // license that can be found in the LICENSE file. | ||||
| // | ||||
| // Meh, we need to get rid of this shit | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| @@ -80,7 +74,7 @@ type router struct { | ||||
| 	subscribers map[string][]*subscriber | ||||
| } | ||||
|  | ||||
| // rpcRouter encapsulates functions that become a server.Router. | ||||
| // rpcRouter encapsulates functions that become a Router. | ||||
| type rpcRouter struct { | ||||
| 	h func(context.Context, Request, interface{}) error | ||||
| 	m func(context.Context, Message) error | ||||
| @@ -96,7 +90,7 @@ func (r rpcRouter) ServeRequest(ctx context.Context, req Request, rsp Response) | ||||
|  | ||||
| func newRpcRouter(opts ...RouterOption) *router { | ||||
| 	return &router{ | ||||
| 		ops:         newRouterOptions(opts...), | ||||
| 		ops:         NewRouterOptions(opts...), | ||||
| 		serviceMap:  make(map[string]*service), | ||||
| 		subscribers: make(map[string][]*subscriber), | ||||
| 	} | ||||
| @@ -180,11 +174,13 @@ func prepareMethod(method reflect.Method, logger log.Logger) *methodType { | ||||
| 		logger.Logf(log.ErrorLevel, "method %v has wrong number of outs: %v", mname, mtype.NumOut()) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	// The return type of the method must be error. | ||||
| 	if returnType := mtype.Out(0); returnType != typeOfError { | ||||
| 		logger.Logf(log.ErrorLevel, "method %v returns %v not error", mname, returnType.String()) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	return &methodType{method: method, ArgType: argType, ReplyType: replyType, ContextType: contextType, stream: stream} | ||||
| } | ||||
|  | ||||
| @@ -195,10 +191,13 @@ func (router *router) sendResponse(sending sync.Locker, req *request, reply inte | ||||
| 	resp.msg = msg | ||||
|  | ||||
| 	resp.msg.Id = req.msg.Id | ||||
|  | ||||
| 	sending.Lock() | ||||
| 	err := cc.Write(resp.msg, reply) | ||||
| 	sending.Unlock() | ||||
|  | ||||
| 	router.freeResponse(resp) | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| @@ -261,6 +260,7 @@ func (s *service) call(ctx context.Context, router *router, sending *sync.Mutex, | ||||
| 	// Invoke the method, providing a new value for the reply. | ||||
| 	fn := func(ctx context.Context, req Request, stream interface{}) error { | ||||
| 		returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(stream)}) | ||||
|  | ||||
| 		if err := returnValues[0].Interface(); err != nil { | ||||
| 			// the function returned an error, we use that | ||||
| 			return err.(error) | ||||
| @@ -288,11 +288,14 @@ func (m *methodType) prepareContext(ctx context.Context) reflect.Value { | ||||
| 	if contextv := reflect.ValueOf(ctx); contextv.IsValid() { | ||||
| 		return contextv | ||||
| 	} | ||||
|  | ||||
| 	return reflect.Zero(m.ContextType) | ||||
| } | ||||
|  | ||||
| func (router *router) getRequest() *request { | ||||
| 	router.reqLock.Lock() | ||||
| 	defer router.reqLock.Unlock() | ||||
|  | ||||
| 	req := router.freeReq | ||||
| 	if req == nil { | ||||
| 		req = new(request) | ||||
| @@ -300,19 +303,22 @@ func (router *router) getRequest() *request { | ||||
| 		router.freeReq = req.next | ||||
| 		*req = request{} | ||||
| 	} | ||||
| 	router.reqLock.Unlock() | ||||
|  | ||||
| 	return req | ||||
| } | ||||
|  | ||||
| func (router *router) freeRequest(req *request) { | ||||
| 	router.reqLock.Lock() | ||||
| 	defer router.reqLock.Unlock() | ||||
|  | ||||
| 	req.next = router.freeReq | ||||
| 	router.freeReq = req | ||||
| 	router.reqLock.Unlock() | ||||
| } | ||||
|  | ||||
| func (router *router) getResponse() *response { | ||||
| 	router.respLock.Lock() | ||||
| 	defer router.respLock.Unlock() | ||||
|  | ||||
| 	resp := router.freeResp | ||||
| 	if resp == nil { | ||||
| 		resp = new(response) | ||||
| @@ -320,15 +326,16 @@ func (router *router) getResponse() *response { | ||||
| 		router.freeResp = resp.next | ||||
| 		*resp = response{} | ||||
| 	} | ||||
| 	router.respLock.Unlock() | ||||
|  | ||||
| 	return resp | ||||
| } | ||||
|  | ||||
| func (router *router) freeResponse(resp *response) { | ||||
| 	router.respLock.Lock() | ||||
| 	defer router.respLock.Unlock() | ||||
|  | ||||
| 	resp.next = router.freeResp | ||||
| 	router.freeResp = resp | ||||
| 	router.respLock.Unlock() | ||||
| } | ||||
|  | ||||
| func (router *router) readRequest(r Request) (service *service, mtype *methodType, req *request, argv, replyv reflect.Value, keepReading bool, err error) { | ||||
| @@ -341,8 +348,10 @@ func (router *router) readRequest(r Request) (service *service, mtype *methodTyp | ||||
| 		} | ||||
| 		// discard body | ||||
| 		cc.ReadBody(nil) | ||||
|  | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// is it a streaming request? then we don't read the body | ||||
| 	if mtype.stream { | ||||
| 		if cc.(codec.Codec).String() != "grpc" { | ||||
| @@ -359,10 +368,12 @@ func (router *router) readRequest(r Request) (service *service, mtype *methodTyp | ||||
| 		argv = reflect.New(mtype.ArgType) | ||||
| 		argIsValue = true | ||||
| 	} | ||||
|  | ||||
| 	// argv guaranteed to be a pointer now. | ||||
| 	if err = cc.ReadBody(argv.Interface()); err != nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if argIsValue { | ||||
| 		argv = argv.Elem() | ||||
| 	} | ||||
| @@ -370,6 +381,7 @@ func (router *router) readRequest(r Request) (service *service, mtype *methodTyp | ||||
| 	if !mtype.stream { | ||||
| 		replyv = reflect.New(mtype.ReplyType.Elem()) | ||||
| 	} | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| @@ -387,6 +399,7 @@ func (router *router) readHeader(cc codec.Reader) (service *service, mtype *meth | ||||
| 			return | ||||
| 		} | ||||
| 		err = errors.New("rpc: router cannot decode request: " + err.Error()) | ||||
|  | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| @@ -399,28 +412,33 @@ func (router *router) readHeader(cc codec.Reader) (service *service, mtype *meth | ||||
| 		err = errors.New("rpc: service/endpoint request ill-formed: " + req.msg.Endpoint) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// Look up the request. | ||||
| 	router.mu.Lock() | ||||
| 	service = router.serviceMap[serviceMethod[0]] | ||||
| 	router.mu.Unlock() | ||||
|  | ||||
| 	if service == nil { | ||||
| 		err = errors.New("rpc: can't find service " + serviceMethod[0]) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	mtype = service.method[serviceMethod[1]] | ||||
| 	if mtype == nil { | ||||
| 		err = errors.New("rpc: can't find method " + serviceMethod[1]) | ||||
| 	} | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (router *router) NewHandler(h interface{}, opts ...HandlerOption) Handler { | ||||
| 	return newRpcHandler(h, opts...) | ||||
| 	return NewRpcHandler(h, opts...) | ||||
| } | ||||
|  | ||||
| func (router *router) Handle(h Handler) error { | ||||
| 	router.mu.Lock() | ||||
| 	defer router.mu.Unlock() | ||||
|  | ||||
| 	if router.serviceMap == nil { | ||||
| 		router.serviceMap = make(map[string]*service) | ||||
| 	} | ||||
| @@ -428,6 +446,7 @@ func (router *router) Handle(h Handler) error { | ||||
| 	if len(h.Name()) == 0 { | ||||
| 		return errors.New("rpc.Handle: handler has no name") | ||||
| 	} | ||||
|  | ||||
| 	if !isExported(h.Name()) { | ||||
| 		return errors.New("rpc.Handle: type " + h.Name() + " is not exported") | ||||
| 	} | ||||
| @@ -460,6 +479,7 @@ func (router *router) Handle(h Handler) error { | ||||
|  | ||||
| 	// save handler | ||||
| 	router.serviceMap[s.name] = s | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| @@ -474,8 +494,10 @@ func (router *router) ServeRequest(ctx context.Context, r Request, rsp Response) | ||||
| 		if req != nil { | ||||
| 			router.freeRequest(req) | ||||
| 		} | ||||
|  | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return service.call(ctx, router, sending, mtype, req, argv, replyv, rsp.Codec()) | ||||
| } | ||||
|  | ||||
| @@ -488,6 +510,7 @@ func (router *router) Subscribe(s Subscriber) error { | ||||
| 	if !ok { | ||||
| 		return fmt.Errorf("invalid subscriber: expected *subscriber") | ||||
| 	} | ||||
|  | ||||
| 	if len(sub.handlers) == 0 { | ||||
| 		return fmt.Errorf("invalid subscriber: no handler functions") | ||||
| 	} | ||||
| @@ -517,10 +540,9 @@ func (router *router) ProcessMessage(ctx context.Context, msg Message) (err erro | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	router.su.RLock() | ||||
| 	// get the subscribers by topic | ||||
| 	router.su.RLock() | ||||
| 	subs, ok := router.subscribers[msg.Topic()] | ||||
| 	// unlock since we only need to get the subs | ||||
| 	router.su.RUnlock() | ||||
| 	if !ok { | ||||
| 		return nil | ||||
|   | ||||
							
								
								
									
										1513
									
								
								server/rpc_server.go
									
									
									
									
									
								
							
							
						
						
									
										1513
									
								
								server/rpc_server.go
									
									
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @@ -12,6 +12,13 @@ type waitGroup struct { | ||||
| 	gg *sync.WaitGroup | ||||
| } | ||||
|  | ||||
| // NewWaitGroup returns a new double waitgroup for global management of processes. | ||||
| func NewWaitGroup(gWg *sync.WaitGroup) *waitGroup { | ||||
| 	return &waitGroup{ | ||||
| 		gg: gWg, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (w *waitGroup) Add(i int) { | ||||
| 	w.lg.Add(i) | ||||
| 	if w.gg != nil { | ||||
|   | ||||
| @@ -8,6 +8,7 @@ import ( | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/google/uuid" | ||||
|  | ||||
| 	"go-micro.dev/v4/codec" | ||||
| 	log "go-micro.dev/v4/logger" | ||||
| 	"go-micro.dev/v4/registry" | ||||
| @@ -140,14 +141,14 @@ var ( | ||||
| 	DefaultName                    = "go.micro.server" | ||||
| 	DefaultVersion                 = "latest" | ||||
| 	DefaultId                      = uuid.New().String() | ||||
| 	DefaultServer           Server = newRpcServer() | ||||
| 	DefaultServer           Server = NewRPCServer() | ||||
| 	DefaultRouter                  = newRpcRouter() | ||||
| 	DefaultRegisterCheck           = func(context.Context) error { return nil } | ||||
| 	DefaultRegisterInterval        = time.Second * 30 | ||||
| 	DefaultRegisterTTL             = time.Second * 90 | ||||
|  | ||||
| 	// NewServer creates a new server. | ||||
| 	NewServer func(...Option) Server = newRpcServer | ||||
| 	NewServer func(...Option) Server = NewRPCServer | ||||
| ) | ||||
|  | ||||
| // DefaultOptions returns config options for the default service. | ||||
| @@ -157,7 +158,7 @@ func DefaultOptions() Options { | ||||
|  | ||||
| func Init(opt ...Option) { | ||||
| 	if DefaultServer == nil { | ||||
| 		DefaultServer = newRpcServer(opt...) | ||||
| 		DefaultServer = NewRPCServer(opt...) | ||||
| 	} | ||||
| 	DefaultServer.Init(opt...) | ||||
| } | ||||
|   | ||||
| @@ -113,7 +113,7 @@ func (s *service) Stop() error { | ||||
| 		err = fn() | ||||
| 	} | ||||
|  | ||||
| 	if err = s.opts.Server.Stop(); err != nil { | ||||
| 	if err := s.opts.Server.Stop(); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| @@ -144,6 +144,7 @@ func (s *service) Run() (err error) { | ||||
| 		if err = s.opts.Profile.Start(); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		defer func() { | ||||
| 			err = s.opts.Profile.Stop() | ||||
| 			if err != nil { | ||||
|   | ||||
							
								
								
									
										286
									
								
								service_test.go
									
									
									
									
									
								
							
							
						
						
									
										286
									
								
								service_test.go
									
									
									
									
									
								
							| @@ -1,286 +0,0 @@ | ||||
| package micro | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"net" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
|  | ||||
| 	"go-micro.dev/v4/client" | ||||
| 	"go-micro.dev/v4/debug/handler" | ||||
| 	proto "go-micro.dev/v4/debug/proto" | ||||
| 	"go-micro.dev/v4/registry" | ||||
| 	"go-micro.dev/v4/server" | ||||
| 	"go-micro.dev/v4/transport" | ||||
| 	"go-micro.dev/v4/util/test" | ||||
| ) | ||||
|  | ||||
| func testShutdown(wg *sync.WaitGroup, cancel func()) { | ||||
| 	// add 1 | ||||
| 	wg.Add(1) | ||||
| 	// shutdown the service | ||||
| 	cancel() | ||||
| 	// wait for stop | ||||
| 	wg.Wait() | ||||
| } | ||||
|  | ||||
| func testService(t testing.TB, ctx context.Context, wg *sync.WaitGroup, name string) Service { | ||||
| 	// add self | ||||
| 	wg.Add(1) | ||||
|  | ||||
| 	r := registry.NewMemoryRegistry(registry.Services(test.Data)) | ||||
|  | ||||
| 	// create service | ||||
| 	srv := NewService( | ||||
| 		Name(name), | ||||
| 		Context(ctx), | ||||
| 		Registry(r), | ||||
| 		AfterStart(func() error { | ||||
| 			wg.Done() | ||||
|  | ||||
| 			return nil | ||||
| 		}), | ||||
| 		AfterStop(func() error { | ||||
| 			wg.Done() | ||||
|  | ||||
| 			return nil | ||||
| 		}), | ||||
| 	) | ||||
|  | ||||
| 	if err := RegisterHandler(srv.Server(), handler.NewHandler(srv.Client())); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	return srv | ||||
| } | ||||
|  | ||||
| func testCustomListenService(ctx context.Context, customListener net.Listener, wg *sync.WaitGroup, name string) Service { | ||||
| 	// add self | ||||
| 	wg.Add(1) | ||||
|  | ||||
| 	r := registry.NewMemoryRegistry(registry.Services(test.Data)) | ||||
|  | ||||
| 	// create service | ||||
| 	srv := NewService( | ||||
| 		Name(name), | ||||
| 		Context(ctx), | ||||
| 		Registry(r), | ||||
| 		// injection customListener | ||||
| 		AddListenOption(server.ListenOption(transport.NetListener(customListener))), | ||||
| 		AfterStart(func() error { | ||||
| 			wg.Done() | ||||
| 			return nil | ||||
| 		}), | ||||
| 		AfterStop(func() error { | ||||
| 			wg.Done() | ||||
| 			return nil | ||||
| 		}), | ||||
| 	) | ||||
|  | ||||
| 	RegisterHandler(srv.Server(), handler.NewHandler(srv.Client())) | ||||
|  | ||||
| 	return srv | ||||
| } | ||||
|  | ||||
| func testRequest(ctx context.Context, c client.Client, name string) error { | ||||
| 	// test call debug | ||||
| 	req := c.NewRequest( | ||||
| 		name, | ||||
| 		"Debug.Health", | ||||
| 		new(proto.HealthRequest), | ||||
| 	) | ||||
|  | ||||
| 	rsp := new(proto.HealthResponse) | ||||
|  | ||||
| 	err := c.Call(context.TODO(), req, rsp) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if rsp.Status != "ok" { | ||||
| 		return errors.New("service response: " + rsp.Status) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // TestService tests running and calling a service. | ||||
| func TestService(t *testing.T) { | ||||
| 	// waitgroup for server start | ||||
| 	var wg sync.WaitGroup | ||||
|  | ||||
| 	// cancellation context | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
|  | ||||
| 	// start test server | ||||
| 	service := testService(t, ctx, &wg, "test.service") | ||||
|  | ||||
| 	go func() { | ||||
| 		// wait for service to start | ||||
| 		wg.Wait() | ||||
|  | ||||
| 		// make a test call | ||||
| 		if err := testRequest(ctx, service.Client(), "test.service"); err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
|  | ||||
| 		// shutdown the service | ||||
| 		testShutdown(&wg, cancel) | ||||
| 	}() | ||||
|  | ||||
| 	// start service | ||||
| 	if err := service.Run(); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func benchmarkCustomListenService(b *testing.B, n int, name string) { | ||||
| 	// create custom listen | ||||
| 	customListen, err := net.Listen("tcp", server.DefaultAddress) | ||||
| 	if err != nil { | ||||
| 		b.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	// stop the timer | ||||
| 	b.StopTimer() | ||||
|  | ||||
| 	// waitgroup for server start | ||||
| 	var wg sync.WaitGroup | ||||
|  | ||||
| 	// cancellation context | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
|  | ||||
| 	// create test server | ||||
| 	service := testCustomListenService(ctx, customListen, &wg, name) | ||||
|  | ||||
| 	// start the server | ||||
| 	go func() { | ||||
| 		if err := service.Run(); err != nil { | ||||
| 			b.Fatal(err) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	// wait for service to start | ||||
| 	wg.Wait() | ||||
|  | ||||
| 	// make a test call to warm the cache | ||||
| 	for i := 0; i < 10; i++ { | ||||
| 		if err := testRequest(ctx, service.Client(), name); err != nil { | ||||
| 			b.Fatal(err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// start the timer | ||||
| 	b.StartTimer() | ||||
|  | ||||
| 	// number of iterations | ||||
| 	for i := 0; i < b.N; i++ { | ||||
| 		// for concurrency | ||||
| 		for j := 0; j < n; j++ { | ||||
| 			wg.Add(1) | ||||
|  | ||||
| 			go func() { | ||||
| 				err := testRequest(ctx, service.Client(), name) | ||||
| 				wg.Done() | ||||
| 				if err != nil { | ||||
| 					b.Fatal(err) | ||||
| 				} | ||||
| 			}() | ||||
| 		} | ||||
|  | ||||
| 		// wait for test completion | ||||
| 		wg.Wait() | ||||
| 	} | ||||
|  | ||||
| 	// stop the timer | ||||
| 	b.StopTimer() | ||||
|  | ||||
| 	// shutdown service | ||||
| 	testShutdown(&wg, cancel) | ||||
| } | ||||
|  | ||||
| func benchmarkService(b *testing.B, n int, name string) { | ||||
| 	// stop the timer | ||||
| 	b.StopTimer() | ||||
|  | ||||
| 	// waitgroup for server start | ||||
| 	var wg sync.WaitGroup | ||||
|  | ||||
| 	// cancellation context | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
|  | ||||
| 	// create test server | ||||
| 	service := testService(b, ctx, &wg, name) | ||||
|  | ||||
| 	// start the server | ||||
| 	go func() { | ||||
| 		if err := service.Run(); err != nil { | ||||
| 			b.Fatal(err) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	// wait for service to start | ||||
| 	wg.Wait() | ||||
|  | ||||
| 	// make a test call to warm the cache | ||||
| 	for i := 0; i < 10; i++ { | ||||
| 		if err := testRequest(ctx, service.Client(), name); err != nil { | ||||
| 			b.Fatal(err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// start the timer | ||||
| 	b.StartTimer() | ||||
|  | ||||
| 	// number of iterations | ||||
| 	for i := 0; i < b.N; i++ { | ||||
| 		// for concurrency | ||||
| 		for j := 0; j < n; j++ { | ||||
| 			wg.Add(1) | ||||
|  | ||||
| 			go func() { | ||||
| 				err := testRequest(ctx, service.Client(), name) | ||||
|  | ||||
| 				wg.Done() | ||||
|  | ||||
| 				if err != nil { | ||||
| 					b.Fatal(err) | ||||
| 				} | ||||
| 			}() | ||||
| 		} | ||||
|  | ||||
| 		// wait for test completion | ||||
| 		wg.Wait() | ||||
| 	} | ||||
|  | ||||
| 	// stop the timer | ||||
| 	b.StopTimer() | ||||
|  | ||||
| 	// shutdown service | ||||
| 	testShutdown(&wg, cancel) | ||||
| } | ||||
|  | ||||
| func BenchmarkService1(b *testing.B) { | ||||
| 	benchmarkService(b, 1, "test.service.1") | ||||
| } | ||||
|  | ||||
| func BenchmarkService8(b *testing.B) { | ||||
| 	benchmarkService(b, 8, "test.service.8") | ||||
| } | ||||
|  | ||||
| func BenchmarkService16(b *testing.B) { | ||||
| 	benchmarkService(b, 16, "test.service.16") | ||||
| } | ||||
|  | ||||
| func BenchmarkService32(b *testing.B) { | ||||
| 	benchmarkService(b, 32, "test.service.32") | ||||
| } | ||||
|  | ||||
| func BenchmarkService64(b *testing.B) { | ||||
| 	benchmarkService(b, 64, "test.service.64") | ||||
| } | ||||
|  | ||||
| func BenchmarkCustomListenService1(b *testing.B) { | ||||
| 	benchmarkCustomListenService(b, 1, "test.service.1") | ||||
| } | ||||
							
								
								
									
										64
									
								
								tests/default_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								tests/default_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,64 @@ | ||||
| package tests | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"go-micro.dev/v4" | ||||
| 	"go-micro.dev/v4/broker" | ||||
| 	"go-micro.dev/v4/client" | ||||
| 	"go-micro.dev/v4/registry" | ||||
| 	"go-micro.dev/v4/server" | ||||
| 	"go-micro.dev/v4/transport" | ||||
| 	"go-micro.dev/v4/util/test" | ||||
| ) | ||||
|  | ||||
| func BenchmarkService(b *testing.B) { | ||||
| 	cfg := ServiceTestConfig{ | ||||
| 		Name:       "test-service", | ||||
| 		NewService: newService, | ||||
| 		Parallel:   []int{1, 8, 16, 32, 64}, | ||||
| 		Sequential: []int{0}, | ||||
| 		Streams:    []int{0}, | ||||
| 		//		PubSub:     []int{10}, | ||||
| 	} | ||||
|  | ||||
| 	cfg.Run(b) | ||||
| } | ||||
|  | ||||
| func newService(name string, opts ...micro.Option) (micro.Service, error) { | ||||
| 	r := registry.NewMemoryRegistry( | ||||
| 		registry.Services(test.Data), | ||||
| 	) | ||||
|  | ||||
| 	b := broker.NewMemoryBroker() | ||||
|  | ||||
| 	t := transport.NewHTTPTransport() | ||||
| 	c := client.NewClient( | ||||
| 		client.Transport(t), | ||||
| 		client.Broker(b), | ||||
| 	) | ||||
|  | ||||
| 	s := server.NewRPCServer( | ||||
| 		server.Name(name), | ||||
| 		server.Registry(r), | ||||
| 		server.Transport(t), | ||||
| 		server.Broker(b), | ||||
| 	) | ||||
|  | ||||
| 	if err := s.Init(); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	options := []micro.Option{ | ||||
| 		micro.Name(name), | ||||
| 		micro.Server(s), | ||||
| 		micro.Client(c), | ||||
| 		micro.Registry(r), | ||||
| 		micro.Broker(b), | ||||
| 	} | ||||
| 	options = append(options, opts...) | ||||
|  | ||||
| 	srv := micro.NewService(options...) | ||||
|  | ||||
| 	return srv, nil | ||||
| } | ||||
							
								
								
									
										395
									
								
								tests/service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										395
									
								
								tests/service.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,395 @@ | ||||
| // Package tests implements a testing framwork, and provides default tests. | ||||
| package tests | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/pkg/errors" | ||||
|  | ||||
| 	proto "github.com/go-micro/plugins/v4/server/grpc/proto" | ||||
|  | ||||
| 	"go-micro.dev/v4" | ||||
| 	"go-micro.dev/v4/client" | ||||
| 	"go-micro.dev/v4/debug/handler" | ||||
|  | ||||
| 	pb "go-micro.dev/v4/debug/proto" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	// ErrNoTests returns no test params are set. | ||||
| 	ErrNoTests = errors.New("No tests to run, all values set to 0") | ||||
| 	testTopic  = "Test-Topic" | ||||
| 	errorTopic = "Error-Topic" | ||||
| ) | ||||
|  | ||||
| type parTest func(name string, c client.Client, p, s int, errChan chan error) | ||||
| type testFunc func(name string, c client.Client, errChan chan error) | ||||
|  | ||||
| // ServiceTestConfig allows you to easily test a service configuration by | ||||
| // running predefined tests against your custom service. You only need to | ||||
| // provide a function to create the service, and how many of which test you | ||||
| // want to run. | ||||
| // | ||||
| // The default tests provided, all running with separate parallel routines are: | ||||
| //   - Sequential Call requests | ||||
| //   - Bi-directional streaming | ||||
| //   - Pub/Sub events brokering | ||||
| // | ||||
| // You can provide an array of parallel routines to run for the request and | ||||
| // stream tests. They will be run as matrix tests, so with each possible combination. | ||||
| // Thus, in total (p * seq) + (p * streams) tests will be run. | ||||
| type ServiceTestConfig struct { | ||||
| 	// Service name to use for the tests | ||||
| 	Name string | ||||
| 	// NewService function will be called to setup the new service. | ||||
| 	// It takes in a list of options, which by default will Context and an | ||||
| 	// AfterStart with channel to signal when the service has been started. | ||||
| 	NewService func(name string, opts ...micro.Option) (micro.Service, error) | ||||
| 	// Parallel is the number of prallell routines to use for the tests. | ||||
| 	Parallel []int | ||||
| 	// Sequential is the number of sequential requests to send per parallel process. | ||||
| 	Sequential []int | ||||
| 	// Streams is the nummber of streaming messages to send over the stream per routine. | ||||
| 	Streams []int | ||||
| 	// PubSub is the number of times to publish messages to the broker per routine. | ||||
| 	PubSub []int | ||||
|  | ||||
| 	mu       sync.Mutex | ||||
| 	msgCount int | ||||
| } | ||||
|  | ||||
| // Run will start the benchmark tests. | ||||
| func (stc *ServiceTestConfig) Run(b *testing.B) { | ||||
| 	if err := stc.validate(); err != nil { | ||||
| 		b.Fatal("Failed to validate config", err) | ||||
| 	} | ||||
|  | ||||
| 	// Run routines with sequential requests | ||||
| 	stc.prepBench(b, "req", stc.runParSeqTest, stc.Sequential) | ||||
|  | ||||
| 	// Run routines with streams | ||||
| 	stc.prepBench(b, "streams", stc.runParStreamTest, stc.Streams) | ||||
|  | ||||
| 	// Run routines with pub/sub | ||||
| 	stc.prepBench(b, "pubsub", stc.runBrokerTest, stc.PubSub) | ||||
| } | ||||
|  | ||||
| // prepBench will prepare the benmark by setting the right parameters, | ||||
| // and invoking the test. | ||||
| func (stc *ServiceTestConfig) prepBench(b *testing.B, tName string, test parTest, seq []int) { | ||||
| 	par := stc.Parallel | ||||
|  | ||||
| 	// No requests needed | ||||
| 	if len(seq) == 0 || seq[0] == 0 { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	for _, parallel := range par { | ||||
| 		for _, sequential := range seq { | ||||
| 			// Create the service name for the test | ||||
| 			name := fmt.Sprintf("%s.%dp-%d%s", stc.Name, parallel, sequential, tName) | ||||
|  | ||||
| 			// Run test with parallel routines making each sequential requests | ||||
| 			test := func(name string, c client.Client, errChan chan error) { | ||||
| 				test(name, c, parallel, sequential, errChan) | ||||
| 			} | ||||
|  | ||||
| 			benchmark := func(b *testing.B) { | ||||
| 				b.ReportAllocs() | ||||
| 				stc.runBench(b, name, test) | ||||
| 			} | ||||
|  | ||||
| 			b.Logf("----------- STARTING TEST %s -----------", name) | ||||
|  | ||||
| 			// Run test, return if it fails | ||||
| 			if !b.Run(name, benchmark) { | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // runParSeqTest will make s sequential requests in p parallel routines. | ||||
| func (stc *ServiceTestConfig) runParSeqTest(name string, c client.Client, p, s int, errChan chan error) { | ||||
| 	testParallel(p, func() { | ||||
| 		// Make serial requests | ||||
| 		for z := 0; z < s; z++ { | ||||
| 			if err := testRequest(context.Background(), c, name); err != nil { | ||||
| 				errChan <- errors.Wrapf(err, "[%s] Request failed during testRequest", name) | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // Handle is used as a test handler. | ||||
| func (stc *ServiceTestConfig) Handle(ctx context.Context, msg *proto.Request) error { | ||||
| 	stc.mu.Lock() | ||||
| 	stc.msgCount++ | ||||
| 	stc.mu.Unlock() | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // HandleError is used as a test handler. | ||||
| func (stc *ServiceTestConfig) HandleError(ctx context.Context, msg *proto.Request) error { | ||||
| 	return errors.New("dummy error") | ||||
| } | ||||
|  | ||||
| // runBrokerTest will publish messages to the broker to test pub/sub. | ||||
| func (stc *ServiceTestConfig) runBrokerTest(name string, c client.Client, p, s int, errChan chan error) { | ||||
| 	stc.msgCount = 0 | ||||
|  | ||||
| 	testParallel(p, func() { | ||||
| 		for z := 0; z < s; z++ { | ||||
| 			msg := pb.BusMsg{Msg: "Hello from broker!"} | ||||
| 			if err := c.Publish(context.Background(), c.NewMessage(testTopic, &msg)); err != nil { | ||||
| 				errChan <- errors.Wrap(err, "failed to publish message to broker") | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			msg = pb.BusMsg{Msg: "Some message that will error"} | ||||
| 			if err := c.Publish(context.Background(), c.NewMessage(errorTopic, &msg)); err == nil { | ||||
| 				errChan <- errors.New("Publish is supposed to return an error, but got no error") | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 	}) | ||||
|  | ||||
| 	if stc.msgCount != s*p { | ||||
| 		errChan <- fmt.Errorf("pub/sub does not work properly, invalid message count. Expected %d messaged, but received %d", s*p, stc.msgCount) | ||||
| 		return | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // runParStreamTest will start streaming, and send s messages parallel in p routines. | ||||
| func (stc *ServiceTestConfig) runParStreamTest(name string, c client.Client, p, s int, errChan chan error) { | ||||
| 	testParallel(p, func() { | ||||
| 		// Create a client service | ||||
| 		srv := pb.NewDebugService(name, c) | ||||
|  | ||||
| 		// Establish a connection to server over which we start streaming | ||||
| 		bus, err := srv.MessageBus(context.Background()) | ||||
| 		if err != nil { | ||||
| 			errChan <- errors.Wrap(err, "failed to connect to message bus") | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		// Start streaming requests | ||||
| 		for z := 0; z < s; z++ { | ||||
| 			if err := bus.Send(&pb.BusMsg{Msg: "Hack the world!"}); err != nil { | ||||
| 				errChan <- errors.Wrap(err, "failed to send to  stream") | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			msg, err := bus.Recv() | ||||
| 			if err != nil { | ||||
| 				errChan <- errors.Wrap(err, "failed to receive message from stream") | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			expected := "Request received!" | ||||
| 			if msg.Msg != expected { | ||||
| 				errChan <- fmt.Errorf("stream returned unexpected mesage. Expected '%s', but got '%s'", expected, msg.Msg) | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // validate will make sure the provided test parameters are a legal combination. | ||||
| func (stc *ServiceTestConfig) validate() error { | ||||
| 	lp, lseq, lstr := len(stc.Parallel), len(stc.Sequential), len(stc.Streams) | ||||
|  | ||||
| 	if lp == 0 || (lseq == 0 && lstr == 0) { | ||||
| 		return ErrNoTests | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // runBench will create a service with the provided stc.NewService function, | ||||
| // and run a benchmark on the test function. | ||||
| func (stc *ServiceTestConfig) runBench(b *testing.B, name string, test testFunc) { | ||||
| 	b.StopTimer() | ||||
|  | ||||
| 	// Channel to signal service has started | ||||
| 	started := make(chan struct{}) | ||||
|  | ||||
| 	// Context with cancel to stop the service | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
|  | ||||
| 	opts := []micro.Option{ | ||||
| 		micro.Context(ctx), | ||||
| 		micro.AfterStart(func() error { | ||||
| 			started <- struct{}{} | ||||
| 			return nil | ||||
| 		}), | ||||
| 	} | ||||
|  | ||||
| 	// Create a new service per test | ||||
| 	service, err := stc.NewService(name, opts...) | ||||
| 	if err != nil { | ||||
| 		b.Fatalf("failed to create service: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// Register handler | ||||
| 	if err := pb.RegisterDebugHandler(service.Server(), handler.NewHandler(service.Client())); err != nil { | ||||
| 		b.Fatalf("failed to register handler during initial service setup: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	o := service.Options() | ||||
| 	if err := o.Broker.Connect(); err != nil { | ||||
| 		b.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	// a := new(testService) | ||||
| 	if err := o.Server.Subscribe(o.Server.NewSubscriber(testTopic, stc.Handle)); err != nil { | ||||
| 		b.Fatalf("[%s] Failed to register subscriber: %v", name, err) | ||||
| 	} | ||||
|  | ||||
| 	if err := o.Server.Subscribe(o.Server.NewSubscriber(errorTopic, stc.HandleError)); err != nil { | ||||
| 		b.Fatalf("[%s] Failed to register subscriber: %v", name, err) | ||||
| 	} | ||||
|  | ||||
| 	b.Logf("# == [ Service ] ==================") | ||||
| 	b.Logf("#    * Server: %s", o.Server.String()) | ||||
| 	b.Logf("#    * Client: %s", o.Client.String()) | ||||
| 	b.Logf("#    * Transport: %s", o.Transport.String()) | ||||
| 	b.Logf("#    * Broker: %s", o.Broker.String()) | ||||
| 	b.Logf("#    * Registry: %s", o.Registry.String()) | ||||
| 	b.Logf("#    * Auth: %s", o.Auth.String()) | ||||
| 	b.Logf("#    * Cache: %s", o.Cache.String()) | ||||
| 	b.Logf("#    * Runtime: %s", o.Runtime.String()) | ||||
| 	b.Logf("# ================================") | ||||
|  | ||||
| 	RunBenchmark(b, name, service, test, cancel, started) | ||||
| } | ||||
|  | ||||
| // RunBenchmark will run benchmarks on a provided service. | ||||
| // | ||||
| // A test function can be provided that will be fun b.N times. | ||||
| func RunBenchmark(b *testing.B, name string, service micro.Service, test testFunc, | ||||
| 	cancel context.CancelFunc, started chan struct{}) { | ||||
| 	b.StopTimer() | ||||
|  | ||||
| 	// Receive errors from routines on this channel | ||||
| 	errChan := make(chan error, 1) | ||||
|  | ||||
| 	// Receive singal after service has shutdown | ||||
| 	done := make(chan struct{}) | ||||
|  | ||||
| 	// Start the server | ||||
| 	go func() { | ||||
| 		b.Logf("[%s] Starting server for benchmark", name) | ||||
|  | ||||
| 		if err := service.Run(); err != nil { | ||||
| 			errChan <- errors.Wrapf(err, "[%s] Error occurred during service.Run", name) | ||||
| 		} | ||||
| 		done <- struct{}{} | ||||
| 	}() | ||||
|  | ||||
| 	sigTerm := make(chan struct{}) | ||||
|  | ||||
| 	// Benchmark routine | ||||
| 	go func() { | ||||
| 		defer func() { | ||||
| 			b.StopTimer() | ||||
|  | ||||
| 			// Shutdown service | ||||
| 			b.Logf("[%s] Shutting down", name) | ||||
| 			cancel() | ||||
|  | ||||
| 			// Wait for service to be fully stopped | ||||
| 			<-done | ||||
| 			sigTerm <- struct{}{} | ||||
| 		}() | ||||
|  | ||||
| 		// Wait for service to start | ||||
| 		<-started | ||||
|  | ||||
| 		// Give the registry more time to setup | ||||
| 		time.Sleep(time.Second) | ||||
|  | ||||
| 		b.Logf("[%s] Server started", name) | ||||
|  | ||||
| 		// Make a test call to warm the cache | ||||
| 		for i := 0; i < 10; i++ { | ||||
| 			if err := testRequest(context.Background(), service.Client(), name); err != nil { | ||||
| 				errChan <- errors.Wrapf(err, "[%s] Failure during cache warmup testRequest", name) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		// Check registration | ||||
| 		services, err := service.Options().Registry.GetService(name) | ||||
| 		if err != nil || len(services) == 0 { | ||||
| 			errChan <- fmt.Errorf("service registration must have failed (%d services found), unable to get service: %w", len(services), err) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		// Start benchmark | ||||
| 		b.Logf("[%s] Starting benchtest", name) | ||||
| 		b.ResetTimer() | ||||
| 		b.StartTimer() | ||||
|  | ||||
| 		// Number of iterations | ||||
| 		for i := 0; i < b.N; i++ { | ||||
| 			test(name, service.Client(), errChan) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	// Wait for completion or catch any errors | ||||
| 	select { | ||||
| 	case err := <-errChan: | ||||
| 		b.Fatal(err) | ||||
| 	case <-sigTerm: | ||||
| 		b.Logf("[%s] Completed benchmark", name) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // testParallel will run the test function in p parallel routines. | ||||
| func testParallel(p int, test func()) { | ||||
| 	// Waitgroup to wait for requests to finish | ||||
| 	wg := sync.WaitGroup{} | ||||
|  | ||||
| 	// For concurrency | ||||
| 	for j := 0; j < p; j++ { | ||||
| 		wg.Add(1) | ||||
|  | ||||
| 		go func() { | ||||
| 			defer wg.Done() | ||||
|  | ||||
| 			test() | ||||
| 		}() | ||||
| 	} | ||||
|  | ||||
| 	// Wait for test completion | ||||
| 	wg.Wait() | ||||
| } | ||||
|  | ||||
| // testRequest sends one test request. | ||||
| // It calls the Debug.Health endpoint, and validates if the response returned | ||||
| // contains the expected message. | ||||
| func testRequest(ctx context.Context, c client.Client, name string) error { | ||||
| 	req := c.NewRequest( | ||||
| 		name, | ||||
| 		"Debug.Health", | ||||
| 		new(pb.HealthRequest), | ||||
| 	) | ||||
|  | ||||
| 	rsp := new(pb.HealthResponse) | ||||
|  | ||||
| 	if err := c.Call(ctx, req, rsp); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if rsp.Status != "ok" { | ||||
| 		return errors.New("service response: " + rsp.Status) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										33
									
								
								transport/headers/headers.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								transport/headers/headers.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,33 @@ | ||||
| // headers is a package for internal micro global constants | ||||
| package headers | ||||
|  | ||||
| const ( | ||||
| 	// Message header is a header for internal message communication. | ||||
| 	Message = "Micro-Topic" | ||||
| 	// Request header is a message header for internal request communication. | ||||
| 	Request = "Micro-Service" | ||||
| 	// Error header contains an error message. | ||||
| 	Error = "Micro-Error" | ||||
| 	// Endpoint header. | ||||
| 	Endpoint = "Micro-Endpoint" | ||||
| 	// Method header. | ||||
| 	Method = "Micro-Method" | ||||
| 	// ID header. | ||||
| 	ID = "Micro-ID" | ||||
| 	// Prefix used to prefix headers. | ||||
| 	Prefix = "Micro-" | ||||
| 	// Namespace header. | ||||
| 	Namespace = "Micro-Namespace" | ||||
| 	// Protocol header. | ||||
| 	Protocol = "Micro-Protocol" | ||||
| 	// Target header. | ||||
| 	Target = "Micro-Target" | ||||
| 	// ContentType header. | ||||
| 	ContentType = "Content-Type" | ||||
| 	// SpanID header. | ||||
| 	SpanID = "Micro-Span-ID" | ||||
| 	// TraceIDKey header. | ||||
| 	TraceIDKey = "Micro-Trace-ID" | ||||
| 	// Stream header. | ||||
| 	Stream = "Micro-Stream" | ||||
| ) | ||||
							
								
								
									
										202
									
								
								transport/http_client.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										202
									
								
								transport/http_client.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,202 @@ | ||||
| package transport | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/pkg/errors" | ||||
|  | ||||
| 	log "go-micro.dev/v4/logger" | ||||
| 	"go-micro.dev/v4/util/buf" | ||||
| ) | ||||
|  | ||||
| type httpTransportClient struct { | ||||
| 	ht       *httpTransport | ||||
| 	addr     string | ||||
| 	conn     net.Conn | ||||
| 	dialOpts DialOptions | ||||
| 	once     sync.Once | ||||
|  | ||||
| 	sync.RWMutex | ||||
|  | ||||
| 	// request must be stored for response processing | ||||
| 	req     chan *http.Request | ||||
| 	reqList []*http.Request | ||||
| 	buff    *bufio.Reader | ||||
| 	closed  bool | ||||
|  | ||||
| 	// local/remote ip | ||||
| 	local  string | ||||
| 	remote string | ||||
| } | ||||
|  | ||||
| func (h *httpTransportClient) Local() string { | ||||
| 	return h.local | ||||
| } | ||||
|  | ||||
| func (h *httpTransportClient) Remote() string { | ||||
| 	return h.remote | ||||
| } | ||||
|  | ||||
| func (h *httpTransportClient) Send(m *Message) error { | ||||
| 	logger := h.ht.Options().Logger | ||||
|  | ||||
| 	header := make(http.Header) | ||||
| 	for k, v := range m.Header { | ||||
| 		header.Set(k, v) | ||||
| 	} | ||||
|  | ||||
| 	b := buf.New(bytes.NewBuffer(m.Body)) | ||||
| 	defer func() { | ||||
| 		if err := b.Close(); err != nil { | ||||
| 			logger.Logf(log.ErrorLevel, "failed to close buffer: %v", err) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	req := &http.Request{ | ||||
| 		Method: http.MethodPost, | ||||
| 		URL: &url.URL{ | ||||
| 			Scheme: "http", | ||||
| 			Host:   h.addr, | ||||
| 		}, | ||||
| 		Header:        header, | ||||
| 		Body:          b, | ||||
| 		ContentLength: int64(b.Len()), | ||||
| 		Host:          h.addr, | ||||
| 		Close:         h.dialOpts.ConnClose, | ||||
| 	} | ||||
|  | ||||
| 	if !h.dialOpts.Stream { | ||||
| 		h.Lock() | ||||
| 		if h.closed { | ||||
| 			h.Unlock() | ||||
| 			return io.EOF | ||||
| 		} | ||||
|  | ||||
| 		h.reqList = append(h.reqList, req) | ||||
|  | ||||
| 		select { | ||||
| 		case h.req <- h.reqList[0]: | ||||
| 			h.reqList = h.reqList[1:] | ||||
| 		default: | ||||
| 		} | ||||
| 		h.Unlock() | ||||
| 	} | ||||
|  | ||||
| 	// set timeout if its greater than 0 | ||||
| 	if h.ht.opts.Timeout > time.Duration(0) { | ||||
| 		if err := h.conn.SetDeadline(time.Now().Add(h.ht.opts.Timeout)); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return req.Write(h.conn) | ||||
|  | ||||
| } | ||||
|  | ||||
| // Recv receives a message. | ||||
| func (h *httpTransportClient) Recv(msg *Message) (err error) { | ||||
| 	if msg == nil { | ||||
| 		return errors.New("message passed in is nil") | ||||
| 	} | ||||
|  | ||||
| 	var req *http.Request | ||||
|  | ||||
| 	if !h.dialOpts.Stream { | ||||
| 		rc, ok := <-h.req | ||||
| 		if !ok { | ||||
| 			h.Lock() | ||||
| 			if len(h.reqList) == 0 { | ||||
| 				h.Unlock() | ||||
| 				return io.EOF | ||||
| 			} | ||||
|  | ||||
| 			rc = h.reqList[0] | ||||
| 			h.reqList = h.reqList[1:] | ||||
| 			h.Unlock() | ||||
| 		} | ||||
|  | ||||
| 		req = rc | ||||
| 	} | ||||
|  | ||||
| 	// set timeout if its greater than 0 | ||||
| 	if h.ht.opts.Timeout > time.Duration(0) { | ||||
| 		if err = h.conn.SetDeadline(time.Now().Add(h.ht.opts.Timeout)); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	h.Lock() | ||||
| 	defer h.Unlock() | ||||
|  | ||||
| 	if h.closed { | ||||
| 		return io.EOF | ||||
| 	} | ||||
|  | ||||
| 	rsp, err := http.ReadResponse(h.buff, req) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	defer func() { | ||||
| 		if err = rsp.Body.Close(); err != nil { | ||||
| 			err = errors.Wrap(err, "failed to close body") | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	b, err := io.ReadAll(rsp.Body) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if rsp.StatusCode != http.StatusOK { | ||||
| 		return errors.New(rsp.Status + ": " + string(b)) | ||||
| 	} | ||||
|  | ||||
| 	msg.Body = b | ||||
|  | ||||
| 	if msg.Header == nil { | ||||
| 		msg.Header = make(map[string]string, len(rsp.Header)) | ||||
| 	} | ||||
|  | ||||
| 	for k, v := range rsp.Header { | ||||
| 		if len(v) > 0 { | ||||
| 			msg.Header[k] = v[0] | ||||
| 		} else { | ||||
| 			msg.Header[k] = "" | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (h *httpTransportClient) Close() error { | ||||
| 	if !h.dialOpts.Stream { | ||||
| 		h.once.Do(func() { | ||||
| 			h.Lock() | ||||
| 			h.buff.Reset(nil) | ||||
| 			h.closed = true | ||||
| 			h.Unlock() | ||||
| 			close(h.req) | ||||
| 		}) | ||||
|  | ||||
| 		return h.conn.Close() | ||||
| 	} | ||||
|  | ||||
| 	err := h.conn.Close() | ||||
| 	h.once.Do(func() { | ||||
| 		h.Lock() | ||||
| 		h.buff.Reset(nil) | ||||
| 		h.closed = true | ||||
| 		h.Unlock() | ||||
| 		close(h.req) | ||||
| 	}) | ||||
|  | ||||
| 	return err | ||||
| } | ||||
							
								
								
									
										132
									
								
								transport/http_listener.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										132
									
								
								transport/http_listener.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,132 @@ | ||||
| package transport | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"time" | ||||
|  | ||||
| 	log "go-micro.dev/v4/logger" | ||||
|  | ||||
| 	"golang.org/x/net/http2" | ||||
| 	"golang.org/x/net/http2/h2c" | ||||
| ) | ||||
|  | ||||
| type httpTransportListener struct { | ||||
| 	ht       *httpTransport | ||||
| 	listener net.Listener | ||||
| } | ||||
|  | ||||
| func (h *httpTransportListener) Addr() string { | ||||
| 	return h.listener.Addr().String() | ||||
| } | ||||
|  | ||||
| func (h *httpTransportListener) Close() error { | ||||
| 	return h.listener.Close() | ||||
| } | ||||
|  | ||||
| func (h *httpTransportListener) Accept(fn func(Socket)) error { | ||||
| 	// Create handler mux | ||||
| 	// TODO: see if we should make a plugin out of the mux | ||||
| 	mux := http.NewServeMux() | ||||
|  | ||||
| 	// Register our transport handler | ||||
| 	mux.HandleFunc("/", h.newHandler(fn)) | ||||
|  | ||||
| 	// Get optional handlers | ||||
| 	// TODO: This needs to be documented clearer, and examples provided | ||||
| 	if h.ht.opts.Context != nil { | ||||
| 		handlers, ok := h.ht.opts.Context.Value("http_handlers").(map[string]http.Handler) | ||||
| 		if ok { | ||||
| 			for pattern, handler := range handlers { | ||||
| 				mux.Handle(pattern, handler) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Server ONLY supports HTTP1 + H2C | ||||
| 	srv := &http.Server{ | ||||
| 		Handler:           mux, | ||||
| 		ReadHeaderTimeout: time.Second * 5, | ||||
| 	} | ||||
|  | ||||
| 	// insecure connection use h2c | ||||
| 	if !(h.ht.opts.Secure || h.ht.opts.TLSConfig != nil) { | ||||
| 		srv.Handler = h2c.NewHandler(mux, &http2.Server{}) | ||||
| 	} | ||||
|  | ||||
| 	return srv.Serve(h.listener) | ||||
| } | ||||
|  | ||||
| // newHandler creates a new HTTP transport handler passed to the mux. | ||||
| func (h *httpTransportListener) newHandler(serveConn func(Socket)) func(rsp http.ResponseWriter, req *http.Request) { | ||||
| 	logger := h.ht.opts.Logger | ||||
|  | ||||
| 	return func(rsp http.ResponseWriter, req *http.Request) { | ||||
| 		var ( | ||||
| 			buf *bufio.ReadWriter | ||||
| 			con net.Conn | ||||
| 		) | ||||
|  | ||||
| 		// HTTP1: read a regular request | ||||
| 		if req.ProtoMajor == 1 { | ||||
| 			b, err := io.ReadAll(req.Body) | ||||
| 			if err != nil { | ||||
| 				http.Error(rsp, err.Error(), http.StatusInternalServerError) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			req.Body = io.NopCloser(bytes.NewReader(b)) | ||||
|  | ||||
| 			// Hijack the conn | ||||
| 			// We also don't close the connection here, as it will be closed by | ||||
| 			// the httpTransportSocket | ||||
| 			hj, ok := rsp.(http.Hijacker) | ||||
| 			if !ok { | ||||
| 				// We're screwed | ||||
| 				http.Error(rsp, "cannot serve conn", http.StatusInternalServerError) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			conn, bufrw, err := hj.Hijack() | ||||
| 			if err != nil { | ||||
| 				http.Error(rsp, err.Error(), http.StatusInternalServerError) | ||||
| 				return | ||||
| 			} | ||||
| 			defer func() { | ||||
| 				if err := conn.Close(); err != nil { | ||||
| 					logger.Logf(log.ErrorLevel, "Failed to close TCP connection: %v", err) | ||||
| 				} | ||||
| 			}() | ||||
|  | ||||
| 			buf = bufrw | ||||
| 			con = conn | ||||
| 		} | ||||
|  | ||||
| 		// Buffered reader | ||||
| 		bufr := bufio.NewReader(req.Body) | ||||
|  | ||||
| 		// Save the request | ||||
| 		ch := make(chan *http.Request, 1) | ||||
| 		ch <- req | ||||
|  | ||||
| 		// Create a new transport socket | ||||
| 		sock := &httpTransportSocket{ | ||||
| 			ht:     h.ht, | ||||
| 			w:      rsp, | ||||
| 			r:      req, | ||||
| 			rw:     buf, | ||||
| 			buf:    bufr, | ||||
| 			ch:     ch, | ||||
| 			conn:   con, | ||||
| 			local:  h.Addr(), | ||||
| 			remote: req.RemoteAddr, | ||||
| 			closed: make(chan bool), | ||||
| 		} | ||||
|  | ||||
| 		// Execute the socket | ||||
| 		serveConn(sock) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										263
									
								
								transport/http_socket.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										263
									
								
								transport/http_socket.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,263 @@ | ||||
| package transport | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/pkg/errors" | ||||
| ) | ||||
|  | ||||
| type httpTransportSocket struct { | ||||
| 	ht *httpTransport | ||||
| 	w  http.ResponseWriter | ||||
| 	r  *http.Request | ||||
| 	rw *bufio.ReadWriter | ||||
|  | ||||
| 	mtx sync.RWMutex | ||||
|  | ||||
| 	// the hijacked when using http 1 | ||||
| 	conn net.Conn | ||||
| 	// for the first request | ||||
| 	ch chan *http.Request | ||||
|  | ||||
| 	// h2 things | ||||
| 	buf *bufio.Reader | ||||
| 	// indicate if socket is closed | ||||
| 	closed chan bool | ||||
|  | ||||
| 	// local/remote ip | ||||
| 	local  string | ||||
| 	remote string | ||||
| } | ||||
|  | ||||
| func (h *httpTransportSocket) Local() string { | ||||
| 	return h.local | ||||
| } | ||||
|  | ||||
| func (h *httpTransportSocket) Remote() string { | ||||
| 	return h.remote | ||||
| } | ||||
|  | ||||
| func (h *httpTransportSocket) Recv(msg *Message) error { | ||||
| 	if msg == nil { | ||||
| 		return errors.New("message passed in is nil") | ||||
| 	} | ||||
|  | ||||
| 	if msg.Header == nil { | ||||
| 		msg.Header = make(map[string]string, len(h.r.Header)) | ||||
| 	} | ||||
|  | ||||
| 	if h.r.ProtoMajor == 1 { | ||||
| 		return h.recvHTTP1(msg) | ||||
| 	} | ||||
|  | ||||
| 	return h.recvHTTP2(msg) | ||||
| } | ||||
|  | ||||
| func (h *httpTransportSocket) Send(msg *Message) error { | ||||
| 	// we need to lock to protect the write | ||||
| 	h.mtx.RLock() | ||||
| 	defer h.mtx.RUnlock() | ||||
|  | ||||
| 	if h.r.ProtoMajor == 1 { | ||||
| 		return h.sendHTTP1(msg) | ||||
| 	} | ||||
|  | ||||
| 	return h.sendHTTP2(msg) | ||||
| } | ||||
|  | ||||
| func (h *httpTransportSocket) Close() error { | ||||
| 	h.mtx.Lock() | ||||
| 	defer h.mtx.Unlock() | ||||
|  | ||||
| 	select { | ||||
| 	case <-h.closed: | ||||
| 		return nil | ||||
| 	default: | ||||
| 		// Close the channel | ||||
| 		close(h.closed) | ||||
|  | ||||
| 		// Close the buffer | ||||
| 		if err := h.r.Body.Close(); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (h *httpTransportSocket) error(m *Message) error { | ||||
| 	if h.r.ProtoMajor == 1 { | ||||
| 		rsp := &http.Response{ | ||||
| 			Header:        make(http.Header), | ||||
| 			Body:          io.NopCloser(bytes.NewReader(m.Body)), | ||||
| 			Status:        "500 Internal Server Error", | ||||
| 			StatusCode:    http.StatusInternalServerError, | ||||
| 			Proto:         "HTTP/1.1", | ||||
| 			ProtoMajor:    1, | ||||
| 			ProtoMinor:    1, | ||||
| 			ContentLength: int64(len(m.Body)), | ||||
| 		} | ||||
|  | ||||
| 		for k, v := range m.Header { | ||||
| 			rsp.Header.Set(k, v) | ||||
| 		} | ||||
|  | ||||
| 		return rsp.Write(h.conn) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (h *httpTransportSocket) recvHTTP1(msg *Message) error { | ||||
| 	// set timeout if its greater than 0 | ||||
| 	if h.ht.opts.Timeout > time.Duration(0) { | ||||
| 		if err := h.conn.SetDeadline(time.Now().Add(h.ht.opts.Timeout)); err != nil { | ||||
| 			return errors.Wrap(err, "failed to set deadline") | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	var req *http.Request | ||||
|  | ||||
| 	select { | ||||
| 	// get first request | ||||
| 	case req = <-h.ch: | ||||
| 	// read next request | ||||
| 	default: | ||||
| 		rr, err := http.ReadRequest(h.rw.Reader) | ||||
| 		if err != nil { | ||||
| 			return errors.Wrap(err, "failed to read request") | ||||
| 		} | ||||
|  | ||||
| 		req = rr | ||||
| 	} | ||||
|  | ||||
| 	// read body | ||||
| 	b, err := io.ReadAll(req.Body) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "failed to read body") | ||||
| 	} | ||||
|  | ||||
| 	// set body | ||||
| 	if err := req.Body.Close(); err != nil { | ||||
| 		return errors.Wrap(err, "failed to close body") | ||||
| 	} | ||||
|  | ||||
| 	msg.Body = b | ||||
|  | ||||
| 	// set headers | ||||
| 	for k, v := range req.Header { | ||||
| 		if len(v) > 0 { | ||||
| 			msg.Header[k] = v[0] | ||||
| 		} else { | ||||
| 			msg.Header[k] = "" | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// return early early | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (h *httpTransportSocket) recvHTTP2(msg *Message) error { | ||||
| 	// only process if the socket is open | ||||
| 	select { | ||||
| 	case <-h.closed: | ||||
| 		return io.EOF | ||||
| 	default: | ||||
| 	} | ||||
|  | ||||
| 	// read streaming body | ||||
|  | ||||
| 	// set max buffer size | ||||
| 	s := h.ht.opts.BuffSizeH2 | ||||
| 	if s == 0 { | ||||
| 		s = DefaultBufSizeH2 | ||||
| 	} | ||||
|  | ||||
| 	buf := make([]byte, s) | ||||
|  | ||||
| 	// read the request body | ||||
| 	n, err := h.buf.Read(buf) | ||||
| 	// not an eof error | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// check if we have data | ||||
| 	if n > 0 { | ||||
| 		msg.Body = buf[:n] | ||||
| 	} | ||||
|  | ||||
| 	// set headers | ||||
| 	for k, v := range h.r.Header { | ||||
| 		if len(v) > 0 { | ||||
| 			msg.Header[k] = v[0] | ||||
| 		} else { | ||||
| 			msg.Header[k] = "" | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// set path | ||||
| 	msg.Header[":path"] = h.r.URL.Path | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (h *httpTransportSocket) sendHTTP1(msg *Message) error { | ||||
| 	// make copy of header | ||||
| 	hdr := make(http.Header) | ||||
| 	for k, v := range h.r.Header { | ||||
| 		hdr[k] = v | ||||
| 	} | ||||
|  | ||||
| 	rsp := &http.Response{ | ||||
| 		Header:        hdr, | ||||
| 		Body:          io.NopCloser(bytes.NewReader(msg.Body)), | ||||
| 		Status:        "200 OK", | ||||
| 		StatusCode:    http.StatusOK, | ||||
| 		Proto:         "HTTP/1.1", | ||||
| 		ProtoMajor:    1, | ||||
| 		ProtoMinor:    1, | ||||
| 		ContentLength: int64(len(msg.Body)), | ||||
| 	} | ||||
|  | ||||
| 	for k, v := range msg.Header { | ||||
| 		rsp.Header.Set(k, v) | ||||
| 	} | ||||
|  | ||||
| 	// set timeout if its greater than 0 | ||||
| 	if h.ht.opts.Timeout > time.Duration(0) { | ||||
| 		if err := h.conn.SetDeadline(time.Now().Add(h.ht.opts.Timeout)); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return rsp.Write(h.conn) | ||||
| } | ||||
|  | ||||
| func (h *httpTransportSocket) sendHTTP2(msg *Message) error { | ||||
| 	// only process if the socket is open | ||||
| 	select { | ||||
| 	case <-h.closed: | ||||
| 		return io.EOF | ||||
| 	default: | ||||
| 	} | ||||
|  | ||||
| 	// set headers | ||||
| 	for k, v := range msg.Header { | ||||
| 		h.w.Header().Set(k, v) | ||||
| 	} | ||||
|  | ||||
| 	// write request | ||||
| 	_, err := h.w.Write(msg.Body) | ||||
|  | ||||
| 	// flush the trailers | ||||
| 	h.w.(http.Flusher).Flush() | ||||
|  | ||||
| 	return err | ||||
| } | ||||
| @@ -2,529 +2,41 @@ package transport | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"crypto/tls" | ||||
| 	"errors" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"go-micro.dev/v4/logger" | ||||
| 	maddr "go-micro.dev/v4/util/addr" | ||||
| 	"go-micro.dev/v4/util/buf" | ||||
| 	mnet "go-micro.dev/v4/util/net" | ||||
| 	mls "go-micro.dev/v4/util/tls" | ||||
| 	"golang.org/x/net/http2" | ||||
| 	"golang.org/x/net/http2/h2c" | ||||
| ) | ||||
|  | ||||
| type httpTransport struct { | ||||
| 	opts Options | ||||
| } | ||||
|  | ||||
| type httpTransportClient struct { | ||||
| 	ht       *httpTransport | ||||
| 	addr     string | ||||
| 	conn     net.Conn | ||||
| 	dialOpts DialOptions | ||||
| 	once     sync.Once | ||||
| func NewHTTPTransport(opts ...Option) *httpTransport { | ||||
| 	options := Options{ | ||||
| 		BuffSizeH2: DefaultBufSizeH2, | ||||
| 		Logger:     logger.DefaultLogger, | ||||
| 	} | ||||
|  | ||||
| 	sync.RWMutex | ||||
| 	for _, o := range opts { | ||||
| 		o(&options) | ||||
| 	} | ||||
|  | ||||
| 	// request must be stored for response processing | ||||
| 	r      chan *http.Request | ||||
| 	bl     []*http.Request | ||||
| 	buff   *bufio.Reader | ||||
| 	closed bool | ||||
|  | ||||
| 	// local/remote ip | ||||
| 	local  string | ||||
| 	remote string | ||||
| 	return &httpTransport{opts: options} | ||||
| } | ||||
|  | ||||
| type httpTransportSocket struct { | ||||
| 	ht *httpTransport | ||||
| 	w  http.ResponseWriter | ||||
| 	r  *http.Request | ||||
| 	rw *bufio.ReadWriter | ||||
|  | ||||
| 	mtx sync.RWMutex | ||||
|  | ||||
| 	// the hijacked when using http 1 | ||||
| 	conn net.Conn | ||||
| 	// for the first request | ||||
| 	ch chan *http.Request | ||||
|  | ||||
| 	// h2 things | ||||
| 	buf *bufio.Reader | ||||
| 	// indicate if socket is closed | ||||
| 	closed chan bool | ||||
|  | ||||
| 	// local/remote ip | ||||
| 	local  string | ||||
| 	remote string | ||||
| } | ||||
|  | ||||
| type httpTransportListener struct { | ||||
| 	ht       *httpTransport | ||||
| 	listener net.Listener | ||||
| } | ||||
|  | ||||
| func (h *httpTransportClient) Local() string { | ||||
| 	return h.local | ||||
| } | ||||
|  | ||||
| func (h *httpTransportClient) Remote() string { | ||||
| 	return h.remote | ||||
| } | ||||
|  | ||||
| func (h *httpTransportClient) Send(m *Message) error { | ||||
| 	header := make(http.Header) | ||||
|  | ||||
| 	for k, v := range m.Header { | ||||
| 		header.Set(k, v) | ||||
| 	} | ||||
|  | ||||
| 	b := buf.New(bytes.NewBuffer(m.Body)) | ||||
| 	defer b.Close() | ||||
|  | ||||
| 	req := &http.Request{ | ||||
| 		Method: http.MethodPost, | ||||
| 		URL: &url.URL{ | ||||
| 			Scheme: "http", | ||||
| 			Host:   h.addr, | ||||
| 		}, | ||||
| 		Header:        header, | ||||
| 		Body:          b, | ||||
| 		ContentLength: int64(b.Len()), | ||||
| 		Host:          h.addr, | ||||
| 	} | ||||
|  | ||||
| 	if !h.dialOpts.Stream { | ||||
| 		h.Lock() | ||||
| 		if h.closed { | ||||
| 			h.Unlock() | ||||
| 			return io.EOF | ||||
| 		} | ||||
|  | ||||
| 		h.bl = append(h.bl, req) | ||||
|  | ||||
| 		select { | ||||
| 		case h.r <- h.bl[0]: | ||||
| 			h.bl = h.bl[1:] | ||||
| 		default: | ||||
| 		} | ||||
| 		h.Unlock() | ||||
| 	} | ||||
|  | ||||
| 	// set timeout if its greater than 0 | ||||
| 	if h.ht.opts.Timeout > time.Duration(0) { | ||||
| 		if err := h.conn.SetDeadline(time.Now().Add(h.ht.opts.Timeout)); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return req.Write(h.conn) | ||||
| } | ||||
|  | ||||
| // Recv receives a message. | ||||
| func (h *httpTransportClient) Recv(msg *Message) error { | ||||
| 	if msg == nil { | ||||
| 		return errors.New("message passed in is nil") | ||||
| 	} | ||||
|  | ||||
| 	var req *http.Request | ||||
|  | ||||
| 	if !h.dialOpts.Stream { | ||||
| 		rc, ok := <-h.r | ||||
| 		if !ok { | ||||
| 			h.Lock() | ||||
| 			if len(h.bl) == 0 { | ||||
| 				h.Unlock() | ||||
| 				return io.EOF | ||||
| 			} | ||||
|  | ||||
| 			rc = h.bl[0] | ||||
| 			h.bl = h.bl[1:] | ||||
| 			h.Unlock() | ||||
| 		} | ||||
|  | ||||
| 		req = rc | ||||
| 	} | ||||
|  | ||||
| 	// set timeout if its greater than 0 | ||||
| 	if h.ht.opts.Timeout > time.Duration(0) { | ||||
| 		if err := h.conn.SetDeadline(time.Now().Add(h.ht.opts.Timeout)); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	h.Lock() | ||||
| 	defer h.Unlock() | ||||
|  | ||||
| 	if h.closed { | ||||
| 		return io.EOF | ||||
| 	} | ||||
|  | ||||
| 	rsp, err := http.ReadResponse(h.buff, req) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	defer rsp.Body.Close() | ||||
|  | ||||
| 	b, err := io.ReadAll(rsp.Body) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if rsp.StatusCode != http.StatusOK { | ||||
| 		return errors.New(rsp.Status + ": " + string(b)) | ||||
| 	} | ||||
|  | ||||
| 	msg.Body = b | ||||
|  | ||||
| 	if msg.Header == nil { | ||||
| 		msg.Header = make(map[string]string, len(rsp.Header)) | ||||
| 	} | ||||
|  | ||||
| 	for k, v := range rsp.Header { | ||||
| 		if len(v) > 0 { | ||||
| 			msg.Header[k] = v[0] | ||||
| 		} else { | ||||
| 			msg.Header[k] = "" | ||||
| 		} | ||||
| func (h *httpTransport) Init(opts ...Option) error { | ||||
| 	for _, o := range opts { | ||||
| 		o(&h.opts) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (h *httpTransportClient) Close() error { | ||||
| 	if !h.dialOpts.Stream { | ||||
| 		h.once.Do(func() { | ||||
| 			h.Lock() | ||||
| 			h.buff.Reset(nil) | ||||
| 			h.closed = true | ||||
| 			h.Unlock() | ||||
| 			close(h.r) | ||||
| 		}) | ||||
|  | ||||
| 		return h.conn.Close() | ||||
| 	} | ||||
|  | ||||
| 	err := h.conn.Close() | ||||
| 	h.once.Do(func() { | ||||
| 		h.Lock() | ||||
| 		h.buff.Reset(nil) | ||||
| 		h.closed = true | ||||
| 		h.Unlock() | ||||
| 		close(h.r) | ||||
| 	}) | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (h *httpTransportSocket) Local() string { | ||||
| 	return h.local | ||||
| } | ||||
|  | ||||
| func (h *httpTransportSocket) Remote() string { | ||||
| 	return h.remote | ||||
| } | ||||
|  | ||||
| func (h *httpTransportSocket) Recv(msg *Message) error { | ||||
| 	if msg == nil { | ||||
| 		return errors.New("message passed in is nil") | ||||
| 	} | ||||
|  | ||||
| 	if msg.Header == nil { | ||||
| 		msg.Header = make(map[string]string, len(h.r.Header)) | ||||
| 	} | ||||
|  | ||||
| 	// process http 1 | ||||
| 	if h.r.ProtoMajor == 1 { | ||||
| 		// set timeout if its greater than 0 | ||||
| 		if h.ht.opts.Timeout > time.Duration(0) { | ||||
| 			h.conn.SetDeadline(time.Now().Add(h.ht.opts.Timeout)) | ||||
| 		} | ||||
|  | ||||
| 		var r *http.Request | ||||
|  | ||||
| 		select { | ||||
| 		// get first request | ||||
| 		case r = <-h.ch: | ||||
| 		// read next request | ||||
| 		default: | ||||
| 			rr, err := http.ReadRequest(h.rw.Reader) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
|  | ||||
| 			r = rr | ||||
| 		} | ||||
|  | ||||
| 		// read body | ||||
| 		b, err := io.ReadAll(r.Body) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		// set body | ||||
| 		r.Body.Close() | ||||
|  | ||||
| 		msg.Body = b | ||||
|  | ||||
| 		// set headers | ||||
| 		for k, v := range r.Header { | ||||
| 			if len(v) > 0 { | ||||
| 				msg.Header[k] = v[0] | ||||
| 			} else { | ||||
| 				msg.Header[k] = "" | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		// return early early | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	// only process if the socket is open | ||||
| 	select { | ||||
| 	case <-h.closed: | ||||
| 		return io.EOF | ||||
| 	default: | ||||
| 		// no op | ||||
| 	} | ||||
|  | ||||
| 	// processing http2 request | ||||
| 	// read streaming body | ||||
|  | ||||
| 	// set max buffer size | ||||
| 	// TODO: adjustable buffer size | ||||
| 	buf := make([]byte, 4*1024*1024) | ||||
|  | ||||
| 	// read the request body | ||||
| 	n, err := h.buf.Read(buf) | ||||
| 	// not an eof error | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// check if we have data | ||||
| 	if n > 0 { | ||||
| 		msg.Body = buf[:n] | ||||
| 	} | ||||
|  | ||||
| 	// set headers | ||||
| 	for k, v := range h.r.Header { | ||||
| 		if len(v) > 0 { | ||||
| 			msg.Header[k] = v[0] | ||||
| 		} else { | ||||
| 			msg.Header[k] = "" | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// set path | ||||
| 	msg.Header[":path"] = h.r.URL.Path | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (h *httpTransportSocket) Send(msg *Message) error { | ||||
| 	if h.r.ProtoMajor == 1 { | ||||
| 		// make copy of header | ||||
| 		hdr := make(http.Header) | ||||
| 		for k, v := range h.r.Header { | ||||
| 			hdr[k] = v | ||||
| 		} | ||||
|  | ||||
| 		rsp := &http.Response{ | ||||
| 			Header:        hdr, | ||||
| 			Body:          io.NopCloser(bytes.NewReader(msg.Body)), | ||||
| 			Status:        "200 OK", | ||||
| 			StatusCode:    http.StatusOK, | ||||
| 			Proto:         "HTTP/1.1", | ||||
| 			ProtoMajor:    1, | ||||
| 			ProtoMinor:    1, | ||||
| 			ContentLength: int64(len(msg.Body)), | ||||
| 		} | ||||
|  | ||||
| 		for k, v := range msg.Header { | ||||
| 			rsp.Header.Set(k, v) | ||||
| 		} | ||||
|  | ||||
| 		// set timeout if its greater than 0 | ||||
| 		if h.ht.opts.Timeout > time.Duration(0) { | ||||
| 			h.conn.SetDeadline(time.Now().Add(h.ht.opts.Timeout)) | ||||
| 		} | ||||
|  | ||||
| 		return rsp.Write(h.conn) | ||||
| 	} | ||||
|  | ||||
| 	// only process if the socket is open | ||||
| 	select { | ||||
| 	case <-h.closed: | ||||
| 		return io.EOF | ||||
| 	default: | ||||
| 		// no op | ||||
| 	} | ||||
|  | ||||
| 	// we need to lock to protect the write | ||||
| 	h.mtx.RLock() | ||||
| 	defer h.mtx.RUnlock() | ||||
|  | ||||
| 	// set headers | ||||
| 	for k, v := range msg.Header { | ||||
| 		h.w.Header().Set(k, v) | ||||
| 	} | ||||
|  | ||||
| 	// write request | ||||
| 	_, err := h.w.Write(msg.Body) | ||||
|  | ||||
| 	// flush the trailers | ||||
| 	h.w.(http.Flusher).Flush() | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (h *httpTransportSocket) error(m *Message) error { | ||||
| 	if h.r.ProtoMajor == 1 { | ||||
| 		rsp := &http.Response{ | ||||
| 			Header:        make(http.Header), | ||||
| 			Body:          io.NopCloser(bytes.NewReader(m.Body)), | ||||
| 			Status:        "500 Internal Server Error", | ||||
| 			StatusCode:    http.StatusInternalServerError, | ||||
| 			Proto:         "HTTP/1.1", | ||||
| 			ProtoMajor:    1, | ||||
| 			ProtoMinor:    1, | ||||
| 			ContentLength: int64(len(m.Body)), | ||||
| 		} | ||||
|  | ||||
| 		for k, v := range m.Header { | ||||
| 			rsp.Header.Set(k, v) | ||||
| 		} | ||||
|  | ||||
| 		return rsp.Write(h.conn) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (h *httpTransportSocket) Close() error { | ||||
| 	h.mtx.Lock() | ||||
| 	defer h.mtx.Unlock() | ||||
| 	select { | ||||
| 	case <-h.closed: | ||||
| 		return nil | ||||
| 	default: | ||||
| 		// close the channel | ||||
| 		close(h.closed) | ||||
|  | ||||
| 		// close the buffer | ||||
| 		h.r.Body.Close() | ||||
|  | ||||
| 		// close the connection | ||||
| 		if h.r.ProtoMajor == 1 { | ||||
| 			return h.conn.Close() | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (h *httpTransportListener) Addr() string { | ||||
| 	return h.listener.Addr().String() | ||||
| } | ||||
|  | ||||
| func (h *httpTransportListener) Close() error { | ||||
| 	return h.listener.Close() | ||||
| } | ||||
|  | ||||
| func (h *httpTransportListener) Accept(fn func(Socket)) error { | ||||
| 	// create handler mux | ||||
| 	mux := http.NewServeMux() | ||||
|  | ||||
| 	// register our transport handler | ||||
| 	mux.HandleFunc("/", func(rsp http.ResponseWriter, req *http.Request) { | ||||
| 		var buf *bufio.ReadWriter | ||||
| 		var con net.Conn | ||||
|  | ||||
| 		// read a regular request | ||||
| 		if req.ProtoMajor == 1 { | ||||
| 			b, err := io.ReadAll(req.Body) | ||||
| 			if err != nil { | ||||
| 				http.Error(rsp, err.Error(), http.StatusInternalServerError) | ||||
| 				return | ||||
| 			} | ||||
| 			req.Body = io.NopCloser(bytes.NewReader(b)) | ||||
| 			// hijack the conn | ||||
| 			hj, ok := rsp.(http.Hijacker) | ||||
| 			if !ok { | ||||
| 				// we're screwed | ||||
| 				http.Error(rsp, "cannot serve conn", http.StatusInternalServerError) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			conn, bufrw, err := hj.Hijack() | ||||
| 			if err != nil { | ||||
| 				http.Error(rsp, err.Error(), http.StatusInternalServerError) | ||||
| 				return | ||||
| 			} | ||||
| 			defer conn.Close() | ||||
| 			buf = bufrw | ||||
| 			con = conn | ||||
| 		} | ||||
|  | ||||
| 		// buffered reader | ||||
| 		bufr := bufio.NewReader(req.Body) | ||||
|  | ||||
| 		// save the request | ||||
| 		ch := make(chan *http.Request, 1) | ||||
| 		ch <- req | ||||
|  | ||||
| 		// create a new transport socket | ||||
| 		sock := &httpTransportSocket{ | ||||
| 			ht:     h.ht, | ||||
| 			w:      rsp, | ||||
| 			r:      req, | ||||
| 			rw:     buf, | ||||
| 			buf:    bufr, | ||||
| 			ch:     ch, | ||||
| 			conn:   con, | ||||
| 			local:  h.Addr(), | ||||
| 			remote: req.RemoteAddr, | ||||
| 			closed: make(chan bool), | ||||
| 		} | ||||
|  | ||||
| 		// execute the socket | ||||
| 		fn(sock) | ||||
| 	}) | ||||
|  | ||||
| 	// get optional handlers | ||||
| 	if h.ht.opts.Context != nil { | ||||
| 		handlers, ok := h.ht.opts.Context.Value("http_handlers").(map[string]http.Handler) | ||||
| 		if ok { | ||||
| 			for pattern, handler := range handlers { | ||||
| 				mux.Handle(pattern, handler) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// default http2 server | ||||
| 	srv := &http.Server{ | ||||
| 		Handler: mux, | ||||
| 	} | ||||
|  | ||||
| 	// insecure connection use h2c | ||||
| 	if !(h.ht.opts.Secure || h.ht.opts.TLSConfig != nil) { | ||||
| 		srv.Handler = h2c.NewHandler(mux, &http2.Server{}) | ||||
| 	} | ||||
|  | ||||
| 	// begin serving | ||||
| 	return srv.Serve(h.listener) | ||||
| } | ||||
|  | ||||
| func (h *httpTransport) Dial(addr string, opts ...DialOption) (Client, error) { | ||||
| 	dopts := DialOptions{ | ||||
| 		Timeout: DefaultDialTimeout, | ||||
| @@ -539,12 +51,11 @@ func (h *httpTransport) Dial(addr string, opts ...DialOption) (Client, error) { | ||||
| 		err  error | ||||
| 	) | ||||
|  | ||||
| 	// TODO: support dial option here rather than using internal config | ||||
| 	if h.opts.Secure || h.opts.TLSConfig != nil { | ||||
| 		config := h.opts.TLSConfig | ||||
| 		if config == nil { | ||||
| 			config = &tls.Config{ | ||||
| 				InsecureSkipVerify: true, | ||||
| 				InsecureSkipVerify: dopts.InsecureSkipVerify, | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| @@ -569,7 +80,7 @@ func (h *httpTransport) Dial(addr string, opts ...DialOption) (Client, error) { | ||||
| 		conn:     conn, | ||||
| 		buff:     bufio.NewReader(conn), | ||||
| 		dialOpts: dopts, | ||||
| 		r:        make(chan *http.Request, 100), | ||||
| 		req:      make(chan *http.Request, 100), | ||||
| 		local:    conn.LocalAddr().String(), | ||||
| 		remote:   conn.RemoteAddr().String(), | ||||
| 	}, nil | ||||
| @@ -586,45 +97,58 @@ func (h *httpTransport) Listen(addr string, opts ...ListenOption) (Listener, err | ||||
| 		err  error | ||||
| 	) | ||||
|  | ||||
| 	if listener := getNetListener(&options); listener != nil { | ||||
| 		fn := func(addr string) (net.Listener, error) { | ||||
| 	switch listener := getNetListener(&options); { | ||||
| 	// Extracted listener from context | ||||
| 	case listener != nil: | ||||
| 		getList := func(addr string) (net.Listener, error) { | ||||
| 			return listener, nil | ||||
| 		} | ||||
|  | ||||
| 		list, err = mnet.Listen(addr, fn) | ||||
| 	} else if h.opts.Secure || h.opts.TLSConfig != nil { | ||||
| 		list, err = mnet.Listen(addr, getList) | ||||
|  | ||||
| 	// Needs to create self signed certificate | ||||
| 	case h.opts.Secure || h.opts.TLSConfig != nil: | ||||
| 		config := h.opts.TLSConfig | ||||
|  | ||||
| 		fn := func(addr string) (net.Listener, error) { | ||||
| 			if config == nil { | ||||
| 				hosts := []string{addr} | ||||
|  | ||||
| 				// check if its a valid host:port | ||||
| 				if host, _, err := net.SplitHostPort(addr); err == nil { | ||||
| 					if len(host) == 0 { | ||||
| 						hosts = maddr.IPs() | ||||
| 					} else { | ||||
| 						hosts = []string{host} | ||||
| 					} | ||||
| 				} | ||||
|  | ||||
| 				// generate a certificate | ||||
| 				cert, err := mls.Certificate(hosts...) | ||||
| 				if err != nil { | ||||
| 					return nil, err | ||||
| 				} | ||||
| 				config = &tls.Config{Certificates: []tls.Certificate{cert}} | ||||
| 		getList := func(addr string) (net.Listener, error) { | ||||
| 			if config != nil { | ||||
| 				return tls.Listen("tcp", addr, config) | ||||
| 			} | ||||
|  | ||||
| 			hosts := []string{addr} | ||||
|  | ||||
| 			// check if its a valid host:port | ||||
| 			if host, _, err := net.SplitHostPort(addr); err == nil { | ||||
| 				if len(host) == 0 { | ||||
| 					hosts = maddr.IPs() | ||||
| 				} else { | ||||
| 					hosts = []string{host} | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 			// generate a certificate | ||||
| 			cert, err := mls.Certificate(hosts...) | ||||
| 			if err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
|  | ||||
| 			config = &tls.Config{ | ||||
| 				Certificates: []tls.Certificate{cert}, | ||||
| 				MinVersion:   tls.VersionTLS12, | ||||
| 			} | ||||
|  | ||||
| 			return tls.Listen("tcp", addr, config) | ||||
| 		} | ||||
|  | ||||
| 		list, err = mnet.Listen(addr, fn) | ||||
| 	} else { | ||||
| 		fn := func(addr string) (net.Listener, error) { | ||||
| 		list, err = mnet.Listen(addr, getList) | ||||
|  | ||||
| 	// Create new basic net listener | ||||
| 	default: | ||||
| 		getList := func(addr string) (net.Listener, error) { | ||||
| 			return net.Listen("tcp", addr) | ||||
| 		} | ||||
|  | ||||
| 		list, err = mnet.Listen(addr, fn) | ||||
| 		list, err = mnet.Listen(addr, getList) | ||||
| 	} | ||||
|  | ||||
| 	if err != nil { | ||||
| @@ -637,14 +161,6 @@ func (h *httpTransport) Listen(addr string, opts ...ListenOption) (Listener, err | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (h *httpTransport) Init(opts ...Option) error { | ||||
| 	for _, o := range opts { | ||||
| 		o(&h.opts) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (h *httpTransport) Options() Options { | ||||
| 	return h.opts | ||||
| } | ||||
| @@ -652,12 +168,3 @@ func (h *httpTransport) Options() Options { | ||||
| func (h *httpTransport) String() string { | ||||
| 	return "http" | ||||
| } | ||||
|  | ||||
| func NewHTTPTransport(opts ...Option) *httpTransport { | ||||
| 	var options Options | ||||
| 	for _, o := range opts { | ||||
| 		o(&options) | ||||
| 	} | ||||
|  | ||||
| 	return &httpTransport{opts: options} | ||||
| } | ||||
|   | ||||
| @@ -10,6 +10,10 @@ import ( | ||||
| 	"go-micro.dev/v4/logger" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	DefaultBufSizeH2 = 4 * 1024 * 1024 | ||||
| ) | ||||
|  | ||||
| type Options struct { | ||||
| 	// Addrs is the list of intermediary addresses to connect to | ||||
| 	Addrs []string | ||||
| @@ -30,6 +34,8 @@ type Options struct { | ||||
| 	Context context.Context | ||||
| 	// Logger is the underline logger | ||||
| 	Logger logger.Logger | ||||
| 	// BuffSizeH2 is the HTTP2 buffer size | ||||
| 	BuffSizeH2 int | ||||
| } | ||||
|  | ||||
| type DialOptions struct { | ||||
| @@ -38,6 +44,10 @@ type DialOptions struct { | ||||
| 	Stream bool | ||||
| 	// Timeout for dialing | ||||
| 	Timeout time.Duration | ||||
| 	// ConnClose sets the Connection header to close | ||||
| 	ConnClose bool | ||||
| 	// InsecureSkipVerify skip TLS verification. | ||||
| 	InsecureSkipVerify bool | ||||
|  | ||||
| 	// TODO: add tls options when dialing | ||||
| 	// Currently set in global options | ||||
| @@ -106,22 +116,46 @@ func WithTimeout(d time.Duration) DialOption { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithLogger sets the underline logger. | ||||
| func WithLogger(l logger.Logger) Option { | ||||
| // WithConnClose sets the Connection header to close. | ||||
| func WithConnClose() DialOption { | ||||
| 	return func(o *DialOptions) { | ||||
| 		o.ConnClose = true | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func WithInsecureSkipVerify(b bool) DialOption { | ||||
| 	return func(o *DialOptions) { | ||||
| 		o.InsecureSkipVerify = b | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Logger sets the underline logger. | ||||
| func Logger(l logger.Logger) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.Logger = l | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // BuffSizeH2 sets the HTTP2 buffer size. | ||||
| // Default is 4 * 1024 * 1024. | ||||
| func BuffSizeH2(size int) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.BuffSizeH2 = size | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // InsecureSkipVerify sets the TLS options to skip verification. | ||||
| // NetListener Set net.Listener for httpTransport. | ||||
| func NetListener(customListener net.Listener) ListenOption { | ||||
| 	return func(o *ListenOptions) { | ||||
| 		if customListener == nil { | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		if o.Context == nil { | ||||
| 			o.Context = context.TODO() | ||||
| 		} | ||||
|  | ||||
| 		o.Context = context.WithValue(o.Context, netListener{}, customListener) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -1,55 +1,30 @@ | ||||
| // addr provides functions to retrieve local IP addresses from device interfaces. | ||||
| package addr | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net" | ||||
|  | ||||
| 	"github.com/pkg/errors" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	privateBlocks []*net.IPNet | ||||
| 	// ErrIPNotFound no IP address found, and explicit IP not provided. | ||||
| 	ErrIPNotFound = errors.New("no IP address found, and explicit IP not provided") | ||||
| ) | ||||
|  | ||||
| func init() { | ||||
| 	for _, b := range []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "100.64.0.0/10", "fd00::/8"} { | ||||
| 		if _, block, err := net.ParseCIDR(b); err == nil { | ||||
| 			privateBlocks = append(privateBlocks, block) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // AppendPrivateBlocks append private network blocks. | ||||
| func AppendPrivateBlocks(bs ...string) { | ||||
| 	for _, b := range bs { | ||||
| 		if _, block, err := net.ParseCIDR(b); err == nil { | ||||
| 			privateBlocks = append(privateBlocks, block) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func isPrivateIP(ipAddr string) bool { | ||||
| 	ip := net.ParseIP(ipAddr) | ||||
| 	for _, priv := range privateBlocks { | ||||
| 		if priv.Contains(ip) { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| // IsLocal tells us whether an ip is local. | ||||
| // IsLocal checks whether an IP belongs to one of the device's interfaces. | ||||
| func IsLocal(addr string) bool { | ||||
| 	// extract the host | ||||
| 	// Extract the host | ||||
| 	host, _, err := net.SplitHostPort(addr) | ||||
| 	if err == nil { | ||||
| 		addr = host | ||||
| 	} | ||||
|  | ||||
| 	// check if its localhost | ||||
| 	if addr == "localhost" { | ||||
| 		return true | ||||
| 	} | ||||
|  | ||||
| 	// check against all local ips | ||||
| 	// Check against all local ips | ||||
| 	for _, ip := range IPs() { | ||||
| 		if addr == ip { | ||||
| 			return true | ||||
| @@ -59,79 +34,53 @@ func IsLocal(addr string) bool { | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| // Extract returns a real ip. | ||||
| // Extract returns a valid IP address. If the address provided is a valid | ||||
| // address, it will be returned directly. Otherwise the available interfaces | ||||
| // be itterated over to find an IP address, prefferably private. | ||||
| func Extract(addr string) (string, error) { | ||||
| 	// if addr specified then its returned | ||||
| 	// if addr is already specified then it's directly returned | ||||
| 	if len(addr) > 0 && (addr != "0.0.0.0" && addr != "[::]" && addr != "::") { | ||||
| 		return addr, nil | ||||
| 	} | ||||
|  | ||||
| 	var ( | ||||
| 		addrs   []net.Addr | ||||
| 		loAddrs []net.Addr | ||||
| 	) | ||||
|  | ||||
| 	ifaces, err := net.Interfaces() | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("Failed to get interfaces! Err: %v", err) | ||||
| 		return "", errors.Wrap(err, "failed to get interfaces") | ||||
| 	} | ||||
|  | ||||
| 	var addrs []net.Addr | ||||
| 	var loAddrs []net.Addr | ||||
| 	for _, iface := range ifaces { | ||||
| 		ifaceAddrs, err := iface.Addrs() | ||||
| 		if err != nil { | ||||
| 			// ignore error, interface can disappear from system | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		if iface.Flags&net.FlagLoopback != 0 { | ||||
| 			loAddrs = append(loAddrs, ifaceAddrs...) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		addrs = append(addrs, ifaceAddrs...) | ||||
| 	} | ||||
|  | ||||
| 	// Add loopback addresses to the end of the list | ||||
| 	addrs = append(addrs, loAddrs...) | ||||
|  | ||||
| 	var ipAddr string | ||||
| 	var publicIP string | ||||
|  | ||||
| 	for _, rawAddr := range addrs { | ||||
| 		var ip net.IP | ||||
| 		switch addr := rawAddr.(type) { | ||||
| 		case *net.IPAddr: | ||||
| 			ip = addr.IP | ||||
| 		case *net.IPNet: | ||||
| 			ip = addr.IP | ||||
| 		default: | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		if !isPrivateIP(ip.String()) { | ||||
| 			publicIP = ip.String() | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		ipAddr = ip.String() | ||||
| 		break | ||||
| 	// Try to find private IP in list, public IP otherwise | ||||
| 	ip, err := findIP(addrs) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	// return private ip | ||||
| 	if len(ipAddr) > 0 { | ||||
| 		a := net.ParseIP(ipAddr) | ||||
| 		if a == nil { | ||||
| 			return "", fmt.Errorf("ip addr %s is invalid", ipAddr) | ||||
| 		} | ||||
| 		return a.String(), nil | ||||
| 	} | ||||
|  | ||||
| 	// return public or virtual ip | ||||
| 	if len(publicIP) > 0 { | ||||
| 		a := net.ParseIP(publicIP) | ||||
| 		if a == nil { | ||||
| 			return "", fmt.Errorf("ip addr %s is invalid", publicIP) | ||||
| 		} | ||||
| 		return a.String(), nil | ||||
| 	} | ||||
|  | ||||
| 	return "", fmt.Errorf("No IP address found, and explicit IP not provided") | ||||
| 	return ip.String(), nil | ||||
| } | ||||
|  | ||||
| // IPs returns all known ips. | ||||
| // IPs returns all available interface IP addresses. | ||||
| func IPs() []string { | ||||
| 	ifaces, err := net.Interfaces() | ||||
| 	if err != nil { | ||||
| @@ -159,17 +108,42 @@ func IPs() []string { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			// dont skip ipv6 addrs | ||||
| 			/* | ||||
| 				ip = ip.To4() | ||||
| 				if ip == nil { | ||||
| 					continue | ||||
| 				} | ||||
| 			*/ | ||||
|  | ||||
| 			ipAddrs = append(ipAddrs, ip.String()) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return ipAddrs | ||||
| } | ||||
|  | ||||
| // findIP will return the first private IP available in the list, | ||||
| // if no private IP is available it will return a public IP if present. | ||||
| func findIP(addresses []net.Addr) (net.IP, error) { | ||||
| 	var publicIP net.IP | ||||
|  | ||||
| 	for _, rawAddr := range addresses { | ||||
| 		var ip net.IP | ||||
| 		switch addr := rawAddr.(type) { | ||||
| 		case *net.IPAddr: | ||||
| 			ip = addr.IP | ||||
| 		case *net.IPNet: | ||||
| 			ip = addr.IP | ||||
| 		default: | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		if !ip.IsPrivate() { | ||||
| 			publicIP = ip | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		// Return private IP if available | ||||
| 		return ip, nil | ||||
| 	} | ||||
|  | ||||
| 	// Return public or virtual IP | ||||
| 	if len(publicIP) > 0 { | ||||
| 		return publicIP, nil | ||||
| 	} | ||||
|  | ||||
| 	return nil, ErrIPNotFound | ||||
| } | ||||
|   | ||||
| @@ -54,24 +54,3 @@ func TestExtractor(t *testing.T) { | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestAppendPrivateBlocks(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		addr   string | ||||
| 		expect bool | ||||
| 	}{ | ||||
| 		{addr: "9.134.71.34", expect: true}, | ||||
| 		{addr: "8.10.110.34", expect: false}, // not in private blocks | ||||
| 	} | ||||
|  | ||||
| 	AppendPrivateBlocks("9.134.0.0/16") | ||||
|  | ||||
| 	for _, test := range tests { | ||||
| 		t.Run(test.addr, func(t *testing.T) { | ||||
| 			res := isPrivateIP(test.addr) | ||||
| 			if res != test.expect { | ||||
| 				t.Fatalf("expected %t got %t", test.expect, res) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -5,6 +5,7 @@ import ( | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/google/uuid" | ||||
|  | ||||
| 	"go-micro.dev/v4/transport" | ||||
| ) | ||||
|  | ||||
| @@ -34,14 +35,21 @@ func newPool(options Options) *pool { | ||||
|  | ||||
| func (p *pool) Close() error { | ||||
| 	p.Lock() | ||||
| 	defer p.Unlock() | ||||
|  | ||||
| 	var err error | ||||
|  | ||||
| 	for k, c := range p.conns { | ||||
| 		for _, conn := range c { | ||||
| 			conn.Client.Close() | ||||
| 			if nerr := conn.Client.Close(); nerr != nil { | ||||
| 				err = nerr | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		delete(p.conns, k) | ||||
| 	} | ||||
| 	p.Unlock() | ||||
| 	return nil | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| // NoOp the Close since we manage it. | ||||
| @@ -61,20 +69,24 @@ func (p *pool) Get(addr string, opts ...transport.DialOption) (Conn, error) { | ||||
| 	p.Lock() | ||||
| 	conns := p.conns[addr] | ||||
|  | ||||
| 	// while we have conns check age and then return one | ||||
| 	// While we have conns check age and then return one | ||||
| 	// otherwise we'll create a new conn | ||||
| 	for len(conns) > 0 { | ||||
| 		conn := conns[len(conns)-1] | ||||
| 		conns = conns[:len(conns)-1] | ||||
| 		p.conns[addr] = conns | ||||
|  | ||||
| 		// if conn is old kill it and move on | ||||
| 		// If conn is old kill it and move on | ||||
| 		if d := time.Since(conn.Created()); d > p.ttl { | ||||
| 			conn.Client.Close() | ||||
| 			if err := conn.Client.Close(); err != nil { | ||||
| 				p.Unlock() | ||||
| 				return nil, err | ||||
| 			} | ||||
|  | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		// we got a good conn, lets unlock and return it | ||||
| 		// We got a good conn, lets unlock and return it | ||||
| 		p.Unlock() | ||||
|  | ||||
| 		return conn, nil | ||||
| @@ -87,6 +99,7 @@ func (p *pool) Get(addr string, opts ...transport.DialOption) (Conn, error) { | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &poolConn{ | ||||
| 		Client:  c, | ||||
| 		id:      uuid.New().String(), | ||||
| @@ -102,13 +115,14 @@ func (p *pool) Release(conn Conn, err error) error { | ||||
|  | ||||
| 	// otherwise put it back for reuse | ||||
| 	p.Lock() | ||||
| 	defer p.Unlock() | ||||
|  | ||||
| 	conns := p.conns[conn.Remote()] | ||||
| 	if len(conns) >= p.size { | ||||
| 		p.Unlock() | ||||
| 		return conn.(*poolConn).Client.Close() | ||||
| 	} | ||||
|  | ||||
| 	p.conns[conn.Remote()] = append(conns, conn.(*poolConn)) | ||||
| 	p.Unlock() | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|   | ||||
| @@ -13,10 +13,11 @@ type Pool interface { | ||||
| 	Close() error | ||||
| 	// Get a connection | ||||
| 	Get(addr string, opts ...transport.DialOption) (Conn, error) | ||||
| 	// Releaes the connection | ||||
| 	// Release the connection | ||||
| 	Release(c Conn, status error) error | ||||
| } | ||||
|  | ||||
| // Conn interface represents a pool connection. | ||||
| type Conn interface { | ||||
| 	// unique id of connection | ||||
| 	Id() string | ||||
| @@ -26,10 +27,12 @@ type Conn interface { | ||||
| 	transport.Client | ||||
| } | ||||
|  | ||||
| // NewPool will return a new pool object. | ||||
| func NewPool(opts ...Option) Pool { | ||||
| 	var options Options | ||||
| 	for _, o := range opts { | ||||
| 		o(&options) | ||||
| 	} | ||||
|  | ||||
| 	return newPool(options) | ||||
| } | ||||
|   | ||||
| @@ -5,7 +5,7 @@ import ( | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	// mock registry data. | ||||
| 	// Data is a set of mock registry data. | ||||
| 	Data = map[string][]*registry.Service{ | ||||
| 		"foo": { | ||||
| 			{ | ||||
| @@ -45,3 +45,14 @@ var ( | ||||
| 		}, | ||||
| 	} | ||||
| ) | ||||
|  | ||||
| // EmptyChannel will empty out a error channel by checking if an error is | ||||
| // present, and if so return the error. | ||||
| func EmptyChannel(c chan error) error { | ||||
| 	select { | ||||
| 	case err := <-c: | ||||
| 		return err | ||||
| 	default: | ||||
| 		return nil | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -10,6 +10,7 @@ import ( | ||||
| 	"go-micro.dev/v4/debug/trace" | ||||
| 	"go-micro.dev/v4/metadata" | ||||
| 	"go-micro.dev/v4/server" | ||||
| 	"go-micro.dev/v4/transport/headers" | ||||
| ) | ||||
|  | ||||
| type fromServiceWrapper struct { | ||||
| @@ -19,10 +20,6 @@ type fromServiceWrapper struct { | ||||
| 	headers metadata.Metadata | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	HeaderPrefix = "Micro-" | ||||
| ) | ||||
|  | ||||
| func (f *fromServiceWrapper) setHeaders(ctx context.Context) context.Context { | ||||
| 	// don't overwrite keys | ||||
| 	return metadata.MergeContext(ctx, f.headers, false) | ||||
| @@ -48,7 +45,7 @@ func FromService(name string, c client.Client) client.Client { | ||||
| 	return &fromServiceWrapper{ | ||||
| 		c, | ||||
| 		metadata.Metadata{ | ||||
| 			HeaderPrefix + "From-Service": name, | ||||
| 			headers.Prefix + "From-Service": name, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
| @@ -159,8 +156,8 @@ func (a *authWrapper) Call(ctx context.Context, req client.Request, rsp interfac | ||||
| 	} | ||||
|  | ||||
| 	// set the namespace header if it has not been set (e.g. on a service to service request) | ||||
| 	if _, ok := metadata.Get(ctx, "Micro-Namespace"); !ok { | ||||
| 		ctx = metadata.Set(ctx, "Micro-Namespace", aa.Options().Namespace) | ||||
| 	if _, ok := metadata.Get(ctx, headers.Namespace); !ok { | ||||
| 		ctx = metadata.Set(ctx, headers.Namespace, aa.Options().Namespace) | ||||
| 	} | ||||
|  | ||||
| 	// check to see if we have a valid access token | ||||
|   | ||||
		Reference in New Issue
	
	Block a user