1
0
mirror of https://github.com/go-micro/go-micro.git synced 2025-01-05 10:20:53 +02:00
go-micro/plugins/transport/rabbitmq/rabbitmq.go
2021-10-12 12:55:53 +01:00

460 lines
7.8 KiB
Go

// Package rabbitmq provides a RabbitMQ transport
package rabbitmq
import (
"fmt"
"io"
"sync"
"time"
"errors"
"github.com/google/uuid"
"github.com/streadway/amqp"
"go-micro.dev/v4/cmd"
"go-micro.dev/v4/transport"
)
const (
// directReplyQueue is the queue name NewTransport stores in the transport's
// replyTo field: clients set it as ReplyTo on published requests and the
// transport consumes replies from it.
directReplyQueue = "amq.rabbitmq.reply-to"
)
// rmqtport is the RabbitMQ transport. It owns the connection used by dialled
// clients and tracks the requests that are still waiting for a reply.
type rmqtport struct {
// conn is the connection clients publish on and replies are consumed from.
conn *rabbitMQConn
// addrs are the addresses handed to newRabbitMQConn (copied from opts.Addrs).
addrs []string
// opts holds the transport options; this file reads Timeout, Secure,
// TLSConfig and Addrs from it.
opts transport.Options
// once makes sure init (which starts the reply consumer) runs a single time.
once sync.Once
// replyTo is set as ReplyTo on outgoing requests and consumed for replies.
replyTo string
// Mutex guards inflight.
sync.Mutex
// inflight maps a client's correlation ID to the channel its replies are
// delivered on.
inflight map[string]chan amqp.Delivery
}
// rmqtportClient is the client side of the transport, created by Dial.
type rmqtportClient struct {
// rt is the transport that created the client.
rt *rmqtport
// addr is the routing key used when publishing.
addr string
// corId is the correlation ID set on every published message and the key
// under which reply is registered with the transport.
corId string
// local is returned by Local; Dial sets it to the correlation ID.
local string
// remote is returned by Remote; Dial sets it to the dialled address.
remote string
// reply receives the deliveries whose correlation ID matches corId.
reply chan amqp.Delivery
}
// rmqtportSocket is the server side of one conversation, created by the
// listener for each correlation ID it has not seen before.
type rmqtportSocket struct {
// rt is the transport; its opts.Timeout bounds Send and Recv.
rt *rmqtport
// conn is the connection replies are published on.
conn *rabbitMQConn
// d is the delivery that created the socket; its CorrelationId and ReplyTo
// are used to address replies.
d *amqp.Delivery
// close is closed by Close; the listener watches it to forget the socket.
close chan bool
// local is returned by Local.
local string
// remote is returned by Remote.
remote string
// Mutex guards bl.
sync.Mutex
// r is the channel Recv reads deliveries from.
r chan *amqp.Delivery
// bl is the backlog of deliveries waiting for room in r.
bl []*amqp.Delivery
}
// rmqtportListener consumes deliveries for one address and hands them to
// per-correlation-ID sockets.
type rmqtportListener struct {
// rt is the transport that created the listener.
rt *rmqtport
// conn is the listener's own connection, created in Listen.
conn *rabbitMQConn
// exit is signalled by Close to stop the accept loop.
exit chan bool
// addr is the name passed to conn.Consume and returned by Addr.
addr string
// RWMutex guards so.
sync.RWMutex
// so holds the open sockets keyed by correlation ID.
so map[string]*rmqtportSocket
}
var (
// DefaultTimeout is the timeout NewTransport applies unless an option
// overrides it. The client's Recv also falls back to it when the configured
// timeout is not positive.
DefaultTimeout = time.Minute
)
// init registers NewTransport under the name "rabbitmq" in
// cmd.DefaultTransports.
func init() {
cmd.DefaultTransports["rabbitmq"] = NewTransport
}
// Local returns the client's local identifier, which Dial sets to the
// client's correlation ID.
func (r *rmqtportClient) Local() string {
return r.local
}
// Remote returns the address the client was dialled with.
func (r *rmqtportClient) Remote() string {
return r.remote
}
// Send publishes m to DefaultExchange with the dialled address as routing
// key. The message carries the client's correlation ID and the transport's
// reply queue so the reply can be routed back to this client.
//
// It returns an error when the transport is not connected, when m is nil, when
// publishing fails, or, if the transport has a non-zero Timeout, when the
// publish has not completed within that time ("timed out").
func (r *rmqtportClient) Send(m *transport.Message) error {
	if !r.rt.conn.IsConnected() {
		return errors.New("Not connected to AMQP")
	}
	// Reject a nil message instead of panicking on m.Header below; the socket's
	// Recv reports the same condition with the same error text.
	if m == nil {
		return errors.New("message passed in is nil")
	}

	headers := make(amqp.Table, len(m.Header))
	for k, v := range m.Header {
		headers[k] = v
	}

	message := amqp.Publishing{
		CorrelationId: r.corId,
		Timestamp:     time.Now().UTC(),
		Body:          m.Body,
		ReplyTo:       r.rt.replyTo,
		Headers:       headers,
	}

	// no timeout
	if r.rt.opts.Timeout == time.Duration(0) {
		return r.rt.conn.Publish(DefaultExchange, r.addr, message)
	}

	// use the timeout; the result channel is buffered so the publishing
	// goroutine can always finish even after we have given up waiting
	ch := make(chan error, 1)
	go func() {
		ch <- r.rt.conn.Publish(DefaultExchange, r.addr, message)
	}()

	// An explicit timer (rather than time.After) is released as soon as the
	// publish completes instead of lingering for the full timeout.
	timer := time.NewTimer(r.rt.opts.Timeout)
	defer timer.Stop()

	select {
	case err := <-ch:
		return err
	case <-timer.C:
		return errors.New("timed out")
	}
}
// Recv waits for the next reply addressed to this client and stores its
// headers and body in m, replacing whatever m held. Header values are
// rendered with fmt's %v verb.
//
// It waits for the transport's Timeout, or DefaultTimeout when that is not
// positive, and then returns a "timed out" error. A nil m is rejected with an
// error instead of panicking.
func (r *rmqtportClient) Recv(m *transport.Message) error {
	if m == nil {
		return errors.New("message passed in is nil")
	}

	timeout := DefaultTimeout
	if r.rt.opts.Timeout > time.Duration(0) {
		timeout = r.rt.opts.Timeout
	}

	// An explicit timer (rather than time.After) is released as soon as the
	// reply arrives instead of lingering for the full timeout.
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case d := <-r.reply:
		header := make(map[string]string, len(d.Headers))
		for k, v := range d.Headers {
			header[k] = fmt.Sprintf("%v", v)
		}
		*m = transport.Message{
			Header: header,
			Body:   d.Body,
		}
		return nil
	case <-timer.C:
		return errors.New("timed out")
	}
}
// Close unregisters the client's reply channel from the transport's in-flight
// table. It always returns nil.
func (r *rmqtportClient) Close() error {
r.rt.popReq(r.corId)
return nil
}
// Local returns the socket's local address, which the listener sets to its
// own address.
func (r *rmqtportSocket) Local() string {
return r.local
}
// Remote returns the socket's remote identifier, which the listener sets to
// the correlation ID of the delivery that created the socket.
func (r *rmqtportSocket) Remote() string {
return r.remote
}
// Recv waits for the next delivery queued for this socket and stores its
// headers and body in m, replacing whatever m held. Header values are
// rendered with fmt's %v verb.
//
// With a positive transport Timeout it gives up after that long and returns a
// "timed out" error; otherwise it waits indefinitely. It returns io.EOF when
// the delivery channel has been closed and an error when m is nil.
func (r *rmqtportSocket) Recv(m *transport.Message) error {
	if m == nil {
		return errors.New("message passed in is nil")
	}

	var d *amqp.Delivery
	var ok bool

	if timeout := r.rt.opts.Timeout; timeout > time.Duration(0) {
		// An explicit timer (rather than time.After) can be released as soon
		// as a delivery arrives instead of lingering for the full timeout.
		timer := time.NewTimer(timeout)
		select {
		case d, ok = <-r.r:
			timer.Stop()
		case <-timer.C:
			return errors.New("timed out")
		}
	} else {
		d, ok = <-r.r
	}

	if !ok {
		return io.EOF
	}

	// A slot in r has just been freed: move the oldest backlog entry into it,
	// without blocking if the listener refilled the slot in the meantime.
	r.Lock()
	if len(r.bl) > 0 {
		select {
		case r.r <- r.bl[0]:
			r.bl = r.bl[1:]
		default:
		}
	}
	r.Unlock()

	header := make(map[string]string, len(d.Headers))
	for k, v := range d.Headers {
		header[k] = fmt.Sprintf("%v", v)
	}
	*m = transport.Message{
		Header: header,
		Body:   d.Body,
	}
	return nil
}
// Send publishes m as a reply to the delivery that created this socket: it is
// sent through the "" exchange with that delivery's ReplyTo as routing key and
// carries the same correlation ID.
//
// It returns an error when m is nil, when publishing fails, or, if the
// transport has a non-zero Timeout, when the publish has not completed within
// that time ("timed out").
func (r *rmqtportSocket) Send(m *transport.Message) error {
	// Reject a nil message instead of panicking on m.Body below; Recv reports
	// the same condition with the same error text.
	if m == nil {
		return errors.New("message passed in is nil")
	}

	msg := amqp.Publishing{
		CorrelationId: r.d.CorrelationId,
		Timestamp:     time.Now().UTC(),
		Body:          m.Body,
		Headers:       make(amqp.Table, len(m.Header)),
	}
	for k, v := range m.Header {
		msg.Headers[k] = v
	}

	// no timeout
	if r.rt.opts.Timeout == time.Duration(0) {
		return r.conn.Publish("", r.d.ReplyTo, msg)
	}

	// use the timeout; the result channel is buffered so the publishing
	// goroutine can always finish even after we have given up waiting
	ch := make(chan error, 1)
	go func() {
		ch <- r.conn.Publish("", r.d.ReplyTo, msg)
	}()

	// An explicit timer (rather than time.After) is released as soon as the
	// publish completes instead of lingering for the full timeout.
	timer := time.NewTimer(r.rt.opts.Timeout)
	defer timer.Stop()

	select {
	case err := <-ch:
		return err
	case <-timer.C:
		return errors.New("timed out")
	}
}
// Close marks the socket as closed by closing its close channel, unless that
// has already happened. It always returns nil.
func (r *rmqtportSocket) Close() error {
	select {
	case <-r.close:
		// already closed; nothing left to do
	default:
		close(r.close)
	}
	return nil
}
// Addr returns the address the listener was created with.
func (r *rmqtportListener) Addr() string {
return r.addr
}
// Close signals the accept loop to stop and closes the listener's
// connection. It always returns nil.
func (r *rmqtportListener) Close() error {
	// Listen creates exit with a buffer of one. A non-blocking send queues the
	// stop signal when none is pending and is a no-op otherwise, so a repeated
	// Close (or a Close with no accept loop draining exit) cannot block forever.
	select {
	case r.exit <- true:
	default:
	}
	r.conn.Close()
	return nil
}
// Accept consumes deliveries for the listener's address and calls fn, in its
// own goroutine, with a socket for every new correlation ID. It keeps going
// (re-initialising the connection whenever it is found disconnected) until
// the listener is closed, in which case it returns nil, or until consuming
// fails, in which case it returns that error.
func (r *rmqtportListener) Accept(fn func(transport.Socket)) error {
	for {
		// (re)establish the connection before consuming
		if connected := r.conn.IsConnected(); !connected {
			<-r.conn.Init(r.rt.opts.Secure, r.rt.opts.TLSConfig)
		}

		stop, err := r.accept(fn)
		switch {
		case err != nil:
			return err
		case stop:
			// the listener was closed
			return nil
		}
		// the delivery channel ended without an exit signal: go round again
	}
}
// accept consumes deliveries for the listener's address and routes each one
// to the socket registered for its correlation ID. The first delivery with an
// unknown correlation ID creates a socket, registers it and runs fn with it
// in a new goroutine.
//
// It returns (true, nil) when the listener's exit channel is signalled,
// (false, nil) when the delivery channel is closed, and (false, err) when
// consuming cannot be started.
func (r *rmqtportListener) accept(fn func(transport.Socket)) (bool, error) {
	deliveries, err := r.conn.Consume(r.addr)
	if err != nil {
		return false, err
	}

	// The loop only ends through the returns inside it, so no trailing return
	// is needed (one would be unreachable).
	for {
		select {
		case <-r.exit:
			return true, nil
		case d, ok := <-deliveries:
			if !ok {
				return false, nil
			}

			r.RLock()
			sock, ok := r.so[d.CorrelationId]
			r.RUnlock()

			if !ok {
				sock = &rmqtportSocket{
					rt:     r.rt,
					d:      &d,
					r:      make(chan *amqp.Delivery, 1),
					conn:   r.conn,
					close:  make(chan bool, 1),
					local:  r.Addr(),
					remote: d.CorrelationId,
				}

				r.Lock()
				r.so[sock.d.CorrelationId] = sock
				r.Unlock()

				// forget the socket once it has been closed
				go func() {
					<-sock.close
					r.Lock()
					delete(r.so, sock.d.CorrelationId)
					r.Unlock()
				}()

				go fn(sock)
			}

			// drop deliveries for a socket that has already been closed
			select {
			case <-sock.close:
				continue
			default:
			}

			// Queue the delivery behind any backlog, then hand the oldest
			// entry to the socket if its single-slot channel has room; the
			// socket's Recv pulls further backlog entries as it drains.
			sock.Lock()
			sock.bl = append(sock.bl, &d)
			select {
			case sock.r <- sock.bl[0]:
				sock.bl = sock.bl[1:]
			default:
			}
			sock.Unlock()
		}
	}
}
// putReq registers a fresh single-slot reply channel under id in the
// in-flight table, replacing any previous entry, and returns it.
func (r *rmqtport) putReq(id string) chan amqp.Delivery {
	reply := make(chan amqp.Delivery, 1)

	r.Lock()
	defer r.Unlock()

	r.inflight[id] = reply
	return reply
}
// getReq returns the reply channel registered under id, or nil when there is
// none.
func (r *rmqtport) getReq(id string) chan amqp.Delivery {
	r.Lock()
	reply := r.inflight[id] // a missing key yields the nil channel
	r.Unlock()
	return reply
}
// popReq removes the reply channel registered under id, if any.
func (r *rmqtport) popReq(id string) {
	r.Lock()
	// delete is a no-op for a key that is not present
	delete(r.inflight, id)
	r.Unlock()
}
// init waits for the transport's connection to initialise, declares the reply
// queue and starts a goroutine that passes every reply to handle in a
// goroutine of its own. Dial runs it through r.once. If declaring or consuming
// the reply queue fails, the error is dropped and no consumer is started.
func (r *rmqtport) init() {
	<-r.conn.Init(r.opts.Secure, r.opts.TLSConfig)

	err := r.conn.Channel.DeclareReplyQueue(r.replyTo)
	if err != nil {
		return
	}

	replies, err := r.conn.Channel.ConsumeQueue(r.replyTo)
	if err != nil {
		return
	}

	go func() {
		for {
			reply, open := <-replies
			if !open {
				return
			}
			go r.handle(reply)
		}
	}()
}
// handle forwards delivery to the reply channel registered for its
// correlation ID, blocking until the channel accepts it. Deliveries with no
// registered channel are dropped.
func (r *rmqtport) handle(delivery amqp.Delivery) {
	if reply := r.getReq(delivery.CorrelationId); reply != nil {
		reply <- delivery
	}
}
// Dial returns a client that publishes to addr. Every client gets a random
// UUID as correlation ID, which is also reported as its local address, and a
// reply channel registered under that ID. The first call also runs init. The
// dial options are not used. An error is returned only when the UUID cannot
// be generated.
func (r *rmqtport) Dial(addr string, opts ...transport.DialOption) (transport.Client, error) {
	id, err := uuid.NewRandom()
	if err != nil {
		return nil, err
	}
	corID := id.String()

	r.once.Do(r.init)

	client := &rmqtportClient{
		rt:     r,
		addr:   addr,
		remote: addr,
		corId:  corID,
		local:  corID,
		reply:  r.putReq(corID),
	}
	return client, nil
}
// Listen returns a listener for addr backed by a connection of its own, which
// it waits to initialise. An empty addr or ":0" is replaced by a random UUID.
// The listen options are not used. An error is returned only when that UUID
// cannot be generated.
func (r *rmqtport) Listen(addr string, opts ...transport.ListenOption) (transport.Listener, error) {
	switch addr {
	case "", ":0":
		id, err := uuid.NewRandom()
		if err != nil {
			return nil, err
		}
		addr = id.String()
	}

	conn := newRabbitMQConn("", r.addrs)
	<-conn.Init(r.opts.Secure, r.opts.TLSConfig)

	listener := &rmqtportListener{
		rt:   r,
		conn: conn,
		addr: addr,
		exit: make(chan bool, 1),
		so:   make(map[string]*rmqtportSocket),
	}
	return listener, nil
}
// Init applies opts to the transport's options, closes the current
// connection and replaces it with a new one for the configured addresses.
// It always returns nil.
func (r *rmqtport) Init(opts ...transport.Option) error {
	for i := range opts {
		opts[i](&r.opts)
	}

	addrs := r.opts.Addrs
	r.addrs = addrs

	r.conn.Close()
	r.conn = newRabbitMQConn("", addrs)

	return nil
}
// Options returns a copy of the transport's current options.
func (r *rmqtport) Options() transport.Options {
return r.opts
}
// String returns the transport's name, "rabbitmq".
func (r *rmqtport) String() string {
return "rabbitmq"
}
// NewTransport returns a RabbitMQ transport configured by opts. The timeout
// defaults to DefaultTimeout, replies are received on directReplyQueue, and
// the connection is created for the addresses given in the options.
func NewTransport(opts ...transport.Option) transport.Transport {
	var options transport.Options
	options.Timeout = DefaultTimeout

	for i := range opts {
		opts[i](&options)
	}

	t := &rmqtport{
		replyTo:  directReplyQueue,
		inflight: make(map[string]chan amqp.Delivery),
		addrs:    options.Addrs,
		conn:     newRabbitMQConn("", options.Addrs),
		opts:     options,
	}
	return t
}