mirror of
https://github.com/ManyakRus/crud_generator.git
synced 2025-01-23 09:24:43 +02:00
958 lines
23 KiB
Go
958 lines
23 KiB
Go
package nrpc
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"runtime/debug"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/nats-io/nats.go"
|
|
jsonpb "google.golang.org/protobuf/encoding/protojson"
|
|
"google.golang.org/protobuf/proto"
|
|
)
|
|
|
|
// Header name and value used to recognise a NATS "no responders" status
// message on a reply subject.
const (
	// statusHeader is the message header that carries the status code.
	statusHeader = "Status"
	// noResponderStatus is the status code compared against to detect that
	// nobody is listening on the request subject.
	noResponderStatus = "503"
)
|
|
|
|
// ContextKey is the type of the keys this package uses to store values in a
// context.Context.
type ContextKey int
|
|
|
|
// ErrStreamInvalidMsgCount is returned when the end-of-stream marker reports
// a message count that differs from the number of messages received.
var ErrStreamInvalidMsgCount = errors.New("Stream reply received an incorrect number of messages")
|
|
|
|
//go:generate protoc --go_out=. --go_opt=paths=source_relative nrpc.proto
|
|
|
|
type NatsConn interface {
|
|
Publish(subj string, data []byte) error
|
|
PublishRequest(subj, reply string, data []byte) error
|
|
Request(subj string, data []byte, timeout time.Duration) (*nats.Msg, error)
|
|
|
|
ChanSubscribe(subj string, ch chan *nats.Msg) (*nats.Subscription, error)
|
|
Subscribe(subj string, handler nats.MsgHandler) (*nats.Subscription, error)
|
|
SubscribeSync(subj string) (*nats.Subscription, error)
|
|
}
|
|
|
|
// ReplyInboxMaker returns a new inbox subject for a given nats connection.
|
|
type ReplyInboxMaker func(NatsConn) string
|
|
|
|
// GetReplyInbox is used by StreamCall to get a inbox subject
|
|
// It can be changed by a client lib that needs custom inbox subjects
|
|
var GetReplyInbox ReplyInboxMaker = func(NatsConn) string {
|
|
return nats.NewInbox()
|
|
}
|
|
|
|
func (e *Error) Error() string {
|
|
return fmt.Sprintf("%s error: %s", Error_Type_name[int32(e.Type)], e.Message)
|
|
}
|
|
|
|
func Unmarshal(encoding string, data []byte, msg proto.Message) error {
|
|
switch encoding {
|
|
case "protobuf":
|
|
return proto.Unmarshal(data, msg)
|
|
case "json":
|
|
return jsonpb.Unmarshal(data, msg)
|
|
default:
|
|
return errors.New("Invalid encoding: " + encoding)
|
|
}
|
|
}
|
|
|
|
func UnmarshalResponse(encoding string, data []byte, msg proto.Message) error {
|
|
switch encoding {
|
|
case "protobuf":
|
|
if len(data) > 0 && data[0] == 0 {
|
|
var repErr Error
|
|
if err := proto.Unmarshal(data[1:], &repErr); err != nil {
|
|
return err
|
|
}
|
|
return &repErr
|
|
}
|
|
return proto.Unmarshal(data, msg)
|
|
case "json":
|
|
if len(data) > 13 && bytes.Equal(data[:13], []byte("{\"__error__\":")) {
|
|
var rep map[string]json.RawMessage
|
|
if err := json.Unmarshal(data, &rep); err != nil {
|
|
return err
|
|
}
|
|
errbuf, ok := rep["__error__"]
|
|
if !ok {
|
|
panic("invalid error message")
|
|
}
|
|
var nrpcErr Error
|
|
if err := jsonpb.Unmarshal(errbuf, &nrpcErr); err != nil {
|
|
return err
|
|
}
|
|
return &nrpcErr
|
|
}
|
|
return jsonpb.Unmarshal(data, msg)
|
|
default:
|
|
return errors.New("Invalid encoding: " + encoding)
|
|
}
|
|
}
|
|
|
|
func Marshal(encoding string, msg proto.Message) ([]byte, error) {
|
|
switch encoding {
|
|
case "protobuf":
|
|
return proto.Marshal(msg)
|
|
case "json":
|
|
return jsonpb.Marshal(msg)
|
|
default:
|
|
return nil, errors.New("Invalid encoding: " + encoding)
|
|
}
|
|
}
|
|
|
|
func MarshalErrorResponse(encoding string, repErr *Error) ([]byte, error) {
|
|
switch encoding {
|
|
case "protobuf":
|
|
b, err := proto.Marshal(repErr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return append([]byte{0}, b...), nil
|
|
case "json":
|
|
b, err := jsonpb.Marshal(repErr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return json.Marshal(map[string]json.RawMessage{
|
|
"__error__": json.RawMessage(b),
|
|
})
|
|
default:
|
|
return nil, errors.New("Invalid encoding: " + encoding)
|
|
}
|
|
}
|
|
|
|
// ParseSubject splits a dot-separated subject into its components.
//
// The expected layout is:
//
//	<packageSubject>.<package params>.<serviceSubject>.<service params>.<name>[.<tail>...]
//
// packageSubject may be empty, in which case the subject starts directly
// with the package parameters. packageParamsCount and serviceParamsCount
// give the number of tokens taken as parameters after each prefix.
//
// It returns the package parameters, the service parameters, the token that
// follows them (name) and all remaining tokens (tail). An error is returned
// when the subject has too few tokens or when the package or service prefix
// does not match.
func ParseSubject(
	packageSubject string, packageParamsCount int,
	serviceSubject string, serviceParamsCount int,
	subject string,
) (packageParams []string, serviceParams []string,
	name string, tail []string, err error,
) {
	var packageParts []string
	if packageSubject != "" {
		packageParts = strings.Split(packageSubject, ".")
	}
	serviceParts := strings.Split(serviceSubject, ".")

	// Prefixes, parameters, plus one token for the name.
	minParts := len(packageParts) + packageParamsCount +
		len(serviceParts) + serviceParamsCount + 1

	parts := strings.Split(subject, ".")
	if got := len(parts); got < minParts {
		err = fmt.Errorf(
			"Invalid subject len. Expects number of parts >= %d, got %d",
			minParts, got)
		return
	}

	// next is the index of the first token not consumed yet.
	next := 0
	for _, want := range packageParts {
		if got := parts[next]; got != want {
			err = fmt.Errorf(
				"Invalid subject prefix. Expected '%s', got '%s'",
				want, got)
			return
		}
		next++
	}

	packageParams = parts[next : next+packageParamsCount]
	next += packageParamsCount

	for _, want := range serviceParts {
		if got := parts[next]; got != want {
			err = fmt.Errorf(
				"Invalid subject. Service should be '%s', got '%s'",
				want, got)
			return
		}
		next++
	}

	serviceParams = parts[next : next+serviceParamsCount]
	next += serviceParamsCount

	name = parts[next]
	tail = parts[next+1:]
	return
}
|
|
|
|
// ParseSubjectTail interprets the tokens that follow the method name in a
// subject.
//
// The first methodParamsCount tokens are returned as methodParams. At most
// one extra token may follow; when present it is the encoding, otherwise the
// encoding defaults to "protobuf".
//
// An error is returned when methodParamsCount is negative or when tail does
// not hold methodParamsCount or methodParamsCount+1 tokens.
func ParseSubjectTail(
	methodParamsCount int,
	tail []string,
) (
	methodParams []string, encoding string, err error,
) {
	// A negative count would otherwise panic in the slice expression below.
	if methodParamsCount < 0 {
		err = fmt.Errorf(
			"Invalid method params count: %d", methodParamsCount)
		return
	}
	if len(tail) < methodParamsCount || len(tail) > methodParamsCount+1 {
		err = fmt.Errorf(
			"Invalid subject tail length. Expects %d or %d parts, got %d",
			methodParamsCount, methodParamsCount+1, len(tail),
		)
		return
	}
	methodParams = tail[:methodParamsCount]

	// The length check above leaves either zero or one token after the
	// parameters.
	encoding = "protobuf"
	if len(tail) > methodParamsCount {
		encoding = tail[methodParamsCount]
	}
	return
}
|
|
|
|
func Call(req proto.Message, rep proto.Message, nc NatsConn, subject string, encoding string, timeout time.Duration) error {
|
|
// encode request
|
|
rawRequest, err := Marshal(encoding, req)
|
|
if err != nil {
|
|
log.Printf("nrpc: inner request marshal failed: %v", err)
|
|
return err
|
|
}
|
|
|
|
if encoding != "protobuf" {
|
|
subject += "." + encoding
|
|
}
|
|
|
|
// call
|
|
if _, noreply := rep.(*NoReply); noreply {
|
|
err := nc.Publish(subject, rawRequest)
|
|
if err != nil {
|
|
log.Printf("nrpc: nats publish failed: %v", err)
|
|
}
|
|
return err
|
|
}
|
|
msg, err := nc.Request(subject, rawRequest, timeout)
|
|
if err != nil {
|
|
log.Printf("nrpc: nats request failed: %v", err)
|
|
return err
|
|
}
|
|
|
|
data := msg.Data
|
|
|
|
if err := UnmarshalResponse(encoding, data, rep); err != nil {
|
|
if _, isError := err.(*Error); !isError {
|
|
log.Printf("nrpc: response unmarshal failed: %v", err)
|
|
}
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func Poll(
|
|
req proto.Message, rep proto.Message,
|
|
nc NatsConn, subject string, encoding string, timeout time.Duration,
|
|
maxreplies int, cb func() error,
|
|
) error {
|
|
// encode request
|
|
rawRequest, err := Marshal(encoding, req)
|
|
if err != nil {
|
|
log.Printf("nrpc: inner request marshal failed: %v", err)
|
|
return err
|
|
}
|
|
|
|
if encoding != "protobuf" {
|
|
subject += "." + encoding
|
|
}
|
|
|
|
reply := GetReplyInbox(nc)
|
|
replyC := make(chan *nats.Msg)
|
|
defer close(replyC)
|
|
|
|
sub, err := nc.ChanSubscribe(reply, replyC)
|
|
defer func() {
|
|
if err := sub.Unsubscribe(); err != nil {
|
|
log.Printf("nrpc: nats unsubscribe failed: %v", err)
|
|
}
|
|
}()
|
|
|
|
if err := nc.PublishRequest(subject, reply, rawRequest); err != nil {
|
|
log.Printf("nrpc: nats request failed: %v", err)
|
|
return err
|
|
}
|
|
|
|
timeoutC := time.After(timeout)
|
|
var replyCount int
|
|
|
|
for {
|
|
select {
|
|
case msg := <-replyC:
|
|
replyCount++
|
|
|
|
data := msg.Data
|
|
|
|
if err := UnmarshalResponse(encoding, data, rep); err != nil {
|
|
if _, isError := err.(*Error); !isError {
|
|
log.Printf("nrpc: response unmarshal failed: %v", err)
|
|
}
|
|
return err
|
|
}
|
|
if err := cb(); err != nil {
|
|
return err
|
|
}
|
|
if replyCount == maxreplies {
|
|
return nil
|
|
}
|
|
case <-timeoutC:
|
|
return nats.ErrTimeout
|
|
}
|
|
}
|
|
}
|
|
|
|
const (
|
|
// RequestContextKey is the key for string the request into the context
|
|
RequestContextKey ContextKey = iota
|
|
)
|
|
|
|
// NewRequest creates a Request instance
|
|
func NewRequest(ctx context.Context, conn NatsConn, subject string, replySubject string) *Request {
|
|
return &Request{
|
|
Context: ctx,
|
|
Conn: conn,
|
|
Subject: subject,
|
|
ReplySubject: replySubject,
|
|
CreatedAt: time.Now(),
|
|
}
|
|
}
|
|
|
|
// GetRequest returns the Request associated with a context, or nil if absent
|
|
func GetRequest(ctx context.Context) *Request {
|
|
request, _ := ctx.Value(RequestContextKey).(*Request)
|
|
return request
|
|
}
|
|
|
|
// Request is a server-side incoming request
|
|
type Request struct {
|
|
Context context.Context
|
|
Conn NatsConn
|
|
|
|
isStreamedReply bool
|
|
KeepStreamAlive *KeepStreamAlive
|
|
StreamContext context.Context
|
|
StreamCancel func()
|
|
StreamMsgCount uint32
|
|
streamLock sync.Mutex
|
|
|
|
Subject string
|
|
MethodName string
|
|
SubjectTail []string
|
|
|
|
CreatedAt time.Time
|
|
StartedAt time.Time
|
|
|
|
Encoding string
|
|
NoReply bool
|
|
ReplySubject string
|
|
|
|
PackageParams map[string]string
|
|
ServiceParams map[string]string
|
|
|
|
AfterReply func(r *Request, success bool, replySuccess bool)
|
|
|
|
Handler func(context.Context) (proto.Message, error)
|
|
}
|
|
|
|
// Elapsed duration since request was started
|
|
func (r *Request) Elapsed() time.Duration {
|
|
return time.Since(r.CreatedAt)
|
|
}
|
|
|
|
// Run the handler and capture any error. Returns the response or the error
|
|
// that should be returned to the caller
|
|
func (r *Request) Run() (msg proto.Message, replyError *Error) {
|
|
r.StartedAt = time.Now()
|
|
ctx := r.Context
|
|
if r.StreamedReply() {
|
|
ctx = r.StreamContext
|
|
}
|
|
ctx = context.WithValue(ctx, RequestContextKey, r)
|
|
msg, replyError = CaptureErrors(
|
|
func() (proto.Message, error) {
|
|
return r.Handler(ctx)
|
|
})
|
|
return
|
|
}
|
|
|
|
// RunAndReply calls Run() and send the reply back to the caller
|
|
func (r *Request) RunAndReply() {
|
|
var failed, replyFailed bool
|
|
// In case RunAndReply was called directly, we may need to initialize the
|
|
// streamed reply
|
|
r.setupStreamedReply()
|
|
resp, replyError := r.Run()
|
|
if replyError != nil {
|
|
failed = true
|
|
log.Printf("%s handler failed: %s", r.MethodName, replyError)
|
|
}
|
|
if !r.NoReply {
|
|
if err := r.SendReply(resp, replyError); err != nil {
|
|
replyFailed = true
|
|
log.Printf("%s failed to publish the response: %s", r.MethodName, err)
|
|
}
|
|
}
|
|
if r.AfterReply != nil {
|
|
r.AfterReply(r, !failed, !replyFailed)
|
|
}
|
|
}
|
|
|
|
// PackageParam returns a package parameter value, or "" if absent
|
|
func (r *Request) PackageParam(key string) string {
|
|
if r == nil || r.PackageParams == nil {
|
|
return ""
|
|
}
|
|
return r.PackageParams[key]
|
|
}
|
|
|
|
// ServiceParam returns a package parameter value, or "" if absent
|
|
func (r *Request) ServiceParam(key string) string {
|
|
if r == nil || r.ServiceParams == nil {
|
|
return ""
|
|
}
|
|
return r.ServiceParams[key]
|
|
}
|
|
|
|
// SetPackageParam sets a package param value
|
|
func (r *Request) SetPackageParam(key, value string) {
|
|
if r.PackageParams == nil {
|
|
r.PackageParams = make(map[string]string)
|
|
}
|
|
r.PackageParams[key] = value
|
|
}
|
|
|
|
// SetServiceParam sets a service param value
|
|
func (r *Request) SetServiceParam(key, value string) {
|
|
if r.ServiceParams == nil {
|
|
r.ServiceParams = make(map[string]string)
|
|
}
|
|
r.ServiceParams[key] = value
|
|
}
|
|
|
|
// EnableStreamedReply enables the streamed reply mode
|
|
func (r *Request) EnableStreamedReply() {
|
|
r.isStreamedReply = true
|
|
}
|
|
|
|
// setupStreamedReply initializes the reply stream if needed.
|
|
func (r *Request) setupStreamedReply() {
|
|
r.streamLock.Lock()
|
|
defer r.streamLock.Unlock()
|
|
|
|
if !r.StreamedReply() || r.KeepStreamAlive != nil {
|
|
return
|
|
}
|
|
r.StreamContext, r.StreamCancel = context.WithCancel(r.Context)
|
|
r.KeepStreamAlive = NewKeepStreamAlive(
|
|
r.Conn, r.ReplySubject, r.Encoding, r.StreamCancel)
|
|
}
|
|
|
|
// StreamedReply returns true if the request reply is streamed
|
|
func (r *Request) StreamedReply() bool {
|
|
return r.isStreamedReply
|
|
}
|
|
|
|
// SendStreamReply send a reply a part of a stream
|
|
func (r *Request) SendStreamReply(msg proto.Message) {
|
|
if err := r.sendReply(msg, nil); err != nil {
|
|
log.Printf("nrpc: error publishing response: %s", err)
|
|
r.StreamCancel()
|
|
return
|
|
}
|
|
r.StreamMsgCount++
|
|
}
|
|
|
|
// SendReply sends a reply to the caller
|
|
func (r *Request) SendReply(resp proto.Message, withError *Error) error {
|
|
if r.StreamedReply() {
|
|
r.KeepStreamAlive.Stop()
|
|
if withError == nil {
|
|
return r.sendReply(
|
|
nil, &Error{Type: Error_EOS, MsgCount: r.StreamMsgCount},
|
|
)
|
|
}
|
|
}
|
|
return r.sendReply(resp, withError)
|
|
}
|
|
|
|
// sendReply sends a reply to the caller
|
|
func (r *Request) sendReply(resp proto.Message, withError *Error) error {
|
|
return Publish(resp, withError, r.Conn, r.ReplySubject, r.Encoding)
|
|
}
|
|
|
|
// SendErrorTooBusy cancels the request with a 'SERVERTOOBUSY' error
|
|
func (r *Request) SendErrorTooBusy(msg string) error {
|
|
return r.SendReply(nil, &Error{
|
|
Type: Error_SERVERTOOBUSY,
|
|
Message: msg,
|
|
})
|
|
}
|
|
|
|
// Sentinel errors reported by stream calls.
var (
	// ErrEOS is returned by StreamCallSubscription.Next when the stream
	// ended with a matching message count.
	ErrEOS = errors.New("End of stream")
	// ErrCanceled is reported when the context of a stream call is done.
	ErrCanceled = errors.New("Call canceled")
)
|
|
|
|
func NewStreamCallSubscription(
|
|
ctx context.Context, nc NatsConn, encoding string, subject string,
|
|
timeout time.Duration,
|
|
) (*StreamCallSubscription, error) {
|
|
sub := StreamCallSubscription{
|
|
ctx: ctx,
|
|
nc: nc,
|
|
encoding: encoding,
|
|
subject: subject,
|
|
timeout: timeout,
|
|
timeoutT: time.NewTimer(timeout),
|
|
closed: false,
|
|
subCh: make(chan *nats.Msg, 256),
|
|
nextCh: make(chan *nats.Msg),
|
|
quit: make(chan struct{}),
|
|
errCh: make(chan error, 1),
|
|
}
|
|
ssub, err := nc.ChanSubscribe(subject, sub.subCh)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
go sub.loop(ssub)
|
|
return &sub, nil
|
|
}
|
|
|
|
type StreamCallSubscription struct {
|
|
ctx context.Context
|
|
nc NatsConn
|
|
encoding string
|
|
subject string
|
|
timeout time.Duration
|
|
timeoutT *time.Timer
|
|
closed bool
|
|
subCh chan *nats.Msg
|
|
nextCh chan *nats.Msg
|
|
quit chan struct{}
|
|
errCh chan error
|
|
msgCount uint32
|
|
}
|
|
|
|
func (sub *StreamCallSubscription) stop() {
|
|
close(sub.quit)
|
|
}
|
|
|
|
func (sub *StreamCallSubscription) loop(ssub *nats.Subscription) {
|
|
hbSubject := sub.subject + ".heartbeat"
|
|
ticker := time.NewTicker(time.Second)
|
|
defer ticker.Stop()
|
|
defer ssub.Unsubscribe()
|
|
for {
|
|
select {
|
|
case msg := <-sub.subCh:
|
|
sub.timeoutT.Reset(sub.timeout)
|
|
|
|
if len(msg.Data) == 1 && msg.Data[0] == 0 {
|
|
break
|
|
}
|
|
// Check for no responder status.
|
|
if len(msg.Data) == 0 && msg.Header.Get(statusHeader) == noResponderStatus {
|
|
sub.errCh <- nats.ErrNoResponders
|
|
return
|
|
}
|
|
|
|
sub.nextCh <- msg
|
|
case <-sub.timeoutT.C:
|
|
sub.errCh <- nats.ErrTimeout
|
|
return
|
|
case <-sub.ctx.Done():
|
|
// send a 'lastbeat' and quit
|
|
b, err := Marshal(sub.encoding, &HeartBeat{Lastbeat: true})
|
|
if err != nil {
|
|
err = fmt.Errorf("Error marshaling heartbeat: %s", err)
|
|
sub.errCh <- err
|
|
return
|
|
}
|
|
if err := sub.nc.Publish(hbSubject, b); err != nil {
|
|
err = fmt.Errorf("Error sending heartbeat: %s", err)
|
|
sub.errCh <- err
|
|
return
|
|
}
|
|
sub.errCh <- ErrCanceled
|
|
return
|
|
case <-ticker.C:
|
|
msg, err := Marshal(sub.encoding, &HeartBeat{})
|
|
if err != nil {
|
|
err = fmt.Errorf("Error marshaling heartbeat: %s", err)
|
|
sub.errCh <- err
|
|
return
|
|
}
|
|
if err := sub.nc.Publish(hbSubject, msg); err != nil {
|
|
err = fmt.Errorf("Error sending heartbeat: %s", err)
|
|
sub.errCh <- err
|
|
return
|
|
}
|
|
case <-sub.quit:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (sub *StreamCallSubscription) Next(rep proto.Message) error {
|
|
if sub.closed {
|
|
return nats.ErrBadSubscription
|
|
}
|
|
select {
|
|
case err := <-sub.errCh:
|
|
sub.closed = true
|
|
return err
|
|
case msg := <-sub.nextCh:
|
|
if err := UnmarshalResponse(sub.encoding, msg.Data, rep); err != nil {
|
|
sub.stop()
|
|
sub.closed = true
|
|
if nrpcErr, ok := err.(*Error); ok {
|
|
if nrpcErr.GetMsgCount() != sub.msgCount {
|
|
log.Printf(
|
|
"nrpc: received invalid number of messages. Expected %d, got %d",
|
|
nrpcErr.GetMsgCount(), sub.msgCount)
|
|
}
|
|
if nrpcErr.GetType() == Error_EOS {
|
|
if nrpcErr.GetMsgCount() != sub.msgCount {
|
|
return ErrStreamInvalidMsgCount
|
|
}
|
|
return ErrEOS
|
|
}
|
|
} else {
|
|
log.Printf("nrpc: response unmarshal failed: %v", err)
|
|
}
|
|
return err
|
|
}
|
|
sub.msgCount++
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func StreamCall(ctx context.Context, nc NatsConn, subject string, req proto.Message, encoding string, timeout time.Duration) (*StreamCallSubscription, error) {
|
|
rawRequest, err := Marshal(encoding, req)
|
|
if err != nil {
|
|
log.Printf("nrpc: inner request marshal failed: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
if encoding != "protobuf" {
|
|
subject += "." + encoding
|
|
}
|
|
|
|
reply := GetReplyInbox(nc)
|
|
|
|
streamSub, err := NewStreamCallSubscription(ctx, nc, encoding, reply, timeout)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := nc.PublishRequest(subject, reply, rawRequest); err != nil {
|
|
streamSub.stop()
|
|
return nil, err
|
|
}
|
|
return streamSub, nil
|
|
}
|
|
|
|
func Publish(resp proto.Message, withError *Error, nc NatsConn, subject string, encoding string) error {
|
|
var rawResponse []byte
|
|
var err error
|
|
|
|
if withError != nil {
|
|
rawResponse, err = MarshalErrorResponse(encoding, withError)
|
|
} else {
|
|
rawResponse, err = Marshal(encoding, resp)
|
|
}
|
|
|
|
if err != nil {
|
|
log.Printf("nrpc: rpc response marshal failed: %v", err)
|
|
return err
|
|
}
|
|
|
|
// send response
|
|
if err := nc.Publish(subject, rawResponse); err != nil {
|
|
log.Printf("nrpc: response publish failed: %v", err)
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// CaptureErrors runs a handler and convert error and panics into proper Error
|
|
func CaptureErrors(fn func() (proto.Message, error)) (msg proto.Message, replyError *Error) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
log.Printf("Caught panic: %s\n%s", r, debug.Stack())
|
|
replyError = &Error{
|
|
Type: Error_SERVER,
|
|
Message: fmt.Sprint(r),
|
|
}
|
|
}
|
|
}()
|
|
var err error
|
|
msg, err = fn()
|
|
if err != nil {
|
|
var ok bool
|
|
if replyError, ok = err.(*Error); !ok {
|
|
replyError = &Error{
|
|
Type: Error_CLIENT,
|
|
Message: err.Error(),
|
|
}
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func NewKeepStreamAlive(nc NatsConn, subject string, encoding string, onError func()) *KeepStreamAlive {
|
|
k := KeepStreamAlive{
|
|
nc: nc,
|
|
subject: subject,
|
|
encoding: encoding,
|
|
c: make(chan struct{}),
|
|
onError: onError,
|
|
}
|
|
go k.loop()
|
|
return &k
|
|
}
|
|
|
|
type KeepStreamAlive struct {
|
|
nc NatsConn
|
|
subject string
|
|
encoding string
|
|
c chan struct{}
|
|
onError func()
|
|
}
|
|
|
|
func (k *KeepStreamAlive) Stop() {
|
|
close(k.c)
|
|
}
|
|
|
|
func (k *KeepStreamAlive) loop() {
|
|
hbChan := make(chan *nats.Msg, 256)
|
|
hbSub, err := k.nc.ChanSubscribe(k.subject+".heartbeat", hbChan)
|
|
if err != nil {
|
|
log.Printf("nrpc: could not subscribe to heartbeat: %s", err)
|
|
k.onError()
|
|
}
|
|
defer func() {
|
|
if err := hbSub.Unsubscribe(); err != nil {
|
|
log.Printf("nrpc: error unsubscribing from heartbeat: %s", err)
|
|
}
|
|
}()
|
|
hbDelay := 0
|
|
ticker := time.NewTicker(time.Second)
|
|
for {
|
|
select {
|
|
case msg := <-hbChan:
|
|
var hb HeartBeat
|
|
if err := Unmarshal(k.encoding, msg.Data, &hb); err != nil {
|
|
log.Printf("nrpc: error unmarshaling heartbeat: %s", err)
|
|
ticker.Stop()
|
|
k.onError()
|
|
return
|
|
}
|
|
if hb.Lastbeat {
|
|
log.Printf("nrpc: client canceled the streamed reply. (%s)", k.subject)
|
|
ticker.Stop()
|
|
k.onError()
|
|
return
|
|
}
|
|
hbDelay = 0
|
|
case <-ticker.C:
|
|
hbDelay++
|
|
if hbDelay >= 5 {
|
|
log.Printf("nrpc: No heartbeat received in 5 seconds. Canceling.")
|
|
ticker.Stop()
|
|
k.onError()
|
|
return
|
|
}
|
|
if err := k.nc.Publish(k.subject, []byte{0}); err != nil {
|
|
log.Printf("nrpc: error publishing response: %s", err)
|
|
ticker.Stop()
|
|
k.onError()
|
|
return
|
|
}
|
|
case <-k.c:
|
|
ticker.Stop()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// WorkerPool is a pool of workers
|
|
type WorkerPool struct {
|
|
Context context.Context
|
|
contextCancel context.CancelFunc
|
|
|
|
queue chan *Request
|
|
schedule chan *Request
|
|
waitGroup sync.WaitGroup
|
|
m sync.Mutex
|
|
|
|
size uint
|
|
maxPending uint
|
|
maxPendingDuration time.Duration
|
|
}
|
|
|
|
// NewWorkerPool creates a pool of workers
|
|
func NewWorkerPool(
|
|
ctx context.Context,
|
|
size uint,
|
|
maxPending uint,
|
|
maxPendingDuration time.Duration,
|
|
) *WorkerPool {
|
|
nCtx, cancel := context.WithCancel(ctx)
|
|
pool := WorkerPool{
|
|
Context: nCtx,
|
|
contextCancel: cancel,
|
|
queue: make(chan *Request, maxPending),
|
|
schedule: make(chan *Request),
|
|
maxPending: maxPending,
|
|
maxPendingDuration: maxPendingDuration,
|
|
}
|
|
pool.waitGroup.Add(1)
|
|
go pool.scheduler()
|
|
pool.SetSize(size)
|
|
return &pool
|
|
}
|
|
|
|
func (pool *WorkerPool) getQueue() (queue chan *Request) {
|
|
pool.m.Lock()
|
|
queue = pool.queue
|
|
pool.m.Unlock()
|
|
return
|
|
}
|
|
|
|
func (pool *WorkerPool) scheduler() {
|
|
defer pool.waitGroup.Done()
|
|
|
|
for {
|
|
queue := pool.getQueue()
|
|
if queue == nil {
|
|
return
|
|
}
|
|
queueLoop:
|
|
for request := range queue {
|
|
now := time.Now()
|
|
|
|
pool.m.Lock()
|
|
deadline := request.CreatedAt.Add(pool.maxPendingDuration)
|
|
pool.m.Unlock()
|
|
|
|
if deadline.After(now) {
|
|
// Safety call to setupStreamedReply in case QueueRequest had
|
|
// to time to do it yet
|
|
request.setupStreamedReply()
|
|
select {
|
|
case pool.schedule <- request:
|
|
continue queueLoop
|
|
case <-time.After(deadline.Sub(now)):
|
|
// Too late
|
|
}
|
|
}
|
|
request.SendErrorTooBusy("No worker available")
|
|
}
|
|
}
|
|
}
|
|
|
|
func (pool *WorkerPool) worker() {
|
|
defer pool.waitGroup.Done()
|
|
for request := range pool.schedule {
|
|
if request == nil {
|
|
return
|
|
}
|
|
request.RunAndReply()
|
|
}
|
|
}
|
|
|
|
// SetMaxPending changes the queue size
|
|
func (pool *WorkerPool) SetMaxPending(value uint) {
|
|
if pool.maxPending == value {
|
|
return
|
|
}
|
|
pool.m.Lock()
|
|
defer pool.m.Unlock()
|
|
|
|
oldQueue := pool.queue
|
|
pool.queue = make(chan *Request, value)
|
|
pool.maxPending = value
|
|
|
|
close(oldQueue)
|
|
|
|
// drain the old queue and cancel requests if there are too many
|
|
for request := range oldQueue {
|
|
select {
|
|
case pool.queue <- request:
|
|
default:
|
|
request.SendErrorTooBusy("too many pending requests")
|
|
}
|
|
}
|
|
}
|
|
|
|
// SetMaxPendingDuration changes the max pending delay
|
|
func (pool *WorkerPool) SetMaxPendingDuration(value time.Duration) {
|
|
pool.m.Lock()
|
|
pool.maxPendingDuration = value
|
|
pool.m.Unlock()
|
|
}
|
|
|
|
// SetSize changes the number of workers
|
|
func (pool *WorkerPool) SetSize(size uint) {
|
|
pool.m.Lock()
|
|
defer pool.m.Unlock()
|
|
|
|
if size == pool.size {
|
|
return
|
|
}
|
|
for size < pool.size {
|
|
pool.schedule <- nil
|
|
pool.size--
|
|
}
|
|
for size > pool.size {
|
|
pool.waitGroup.Add(1)
|
|
go pool.worker()
|
|
pool.size++
|
|
}
|
|
}
|
|
|
|
// QueueRequest adds a request to the queue
|
|
// Send a SERVERTOOBUSY error to the client if the queue is full
|
|
func (pool *WorkerPool) QueueRequest(request *Request) error {
|
|
select {
|
|
case pool.getQueue() <- request:
|
|
request.setupStreamedReply()
|
|
return nil
|
|
default:
|
|
return request.SendErrorTooBusy("too many pending requests")
|
|
}
|
|
}
|
|
|
|
// Close stops all the workers and wait for their completion
|
|
// If the workers do not stop before the timeout, their context is canceled
|
|
// Will never return if a request ignores the context
|
|
func (pool *WorkerPool) Close(timeout time.Duration) {
|
|
// Stops all the workers so nothing more gets scheduled
|
|
pool.SetSize(0)
|
|
|
|
pool.m.Lock()
|
|
oldQueue := pool.queue
|
|
pool.queue = nil
|
|
pool.m.Unlock()
|
|
|
|
close(oldQueue)
|
|
for request := range oldQueue {
|
|
request.SendErrorTooBusy("Worker pool shutting down")
|
|
}
|
|
|
|
// Now wait for the workers to stop and cancel the context if they don't
|
|
timer := time.AfterFunc(timeout, pool.contextCancel)
|
|
pool.waitGroup.Wait()
|
|
timer.Stop()
|
|
close(pool.schedule)
|
|
}
|