1
0
mirror of https://github.com/go-micro/go-micro.git synced 2025-01-23 17:53:05 +02:00
2021-10-12 12:55:53 +01:00

525 lines
12 KiB
Go

package snssqs
import (
"context"
"fmt"
"regexp"
"sync"
"time"
"unicode"
"unicode/utf8"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/arn"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sns"
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/aws/aws-sdk-go/service/sts"
"go-micro.dev/v4/broker"
"go-micro.dev/v4/cmd"
"go-micro.dev/v4/logger"
)
// sessClientKey is the context key under which a caller-supplied *session.Session is
// stored in the broker options; getAwsClient reads it and Connect uses that session
// when one is present.
type sessClientKey struct{}

// Fallback values used when the corresponding option is not present in the options
// context. The first three are sent with every ReceiveMessage call made by run
// (MaxNumberOfMessages, VisibilityTimeout and WaitTimeSeconds respectively). The last
// two keep ValidateBody/ValidateHeader switched off in Publish unless the publish
// options enable them.
const (
	defaultMaxMessages             = 1
	defaultVisibilityTimeout       = 3
	defaultWaitSeconds             = 10
	defaultValidateOnPublish       = false
	defaultValidateHeaderOnPublish = false
)
// awsServices is the snssqs broker.Broker returned by NewBroker: it publishes messages
// through SNS and subscribes by polling SQS queues.
type awsServices struct {
	svcSqs    *sqs.SQS         // SQS client; assigned in Connect
	svcSns    *sns.SNS         // SNS client; assigned in Connect
	sess      *session.Session // AWS session the clients are built from
	accountID string           // caller's AWS account ID, used by Publish to build topic ARNs
	options   broker.Options   // broker options; Context carries the plugin-specific settings
}
// subscriber polls one SQS queue (see run) and is returned by Subscribe as the
// broker.Subscriber for that queue.
type subscriber struct {
	options   broker.SubscribeOptions // Subscribe defaults with the caller's options applied
	queueName string                  // queue name, returned by Topic
	svc       *sqs.SQS                // SQS client taken from the broker
	URL       string                  // queue URL resolved in Subscribe via urlFromQueueName
	exit      chan bool               // closed by Unsubscribe to make run return
}
// sqsEvent wraps a message received from an SQS queue by a subscriber; it is the value
// handed to the subscription's broker.Handler in handleMessage.
type sqsEvent struct {
	sMessage  *sqs.Message    // raw SQS message; its ReceiptHandle is used by Ack
	svc       *sqs.SQS        // client Ack uses to delete the message
	m         *broker.Message // broker view of the message (headers and body)
	URL       string          // URL of the queue the message came from
	queueName string          // name of that queue, returned by Topic
	err       error           // error returned by the handler, exposed by Error
}
// init registers NewBroker in go-micro's cmd.DefaultBrokers under the name "snssqs".
func init() {
	cmd.DefaultBrokers["snssqs"] = NewBroker
}
// run polls the subscriber's SQS queue and hands every received message to hdlr via
// handleMessage. Subscribe starts it as a goroutine; it returns once Unsubscribe closes
// s.exit. A single poll can return more than one message, depending on the
// MaxNumberOfMessages option (see getMaxMessages).
//
// After a failed poll, or a poll that returned no messages, run waits one second before
// polling again. Both that wait and an in-flight long poll are interrupted when the
// subscriber is stopped, so Unsubscribe takes effect promptly.
func (s *subscriber) run(hdlr broker.Handler) {
	logger.Debugf("SQS subscription started. Queue:%s, URL: %s", s.queueName, s.URL)

	// ctx is cancelled as soon as s.exit is closed. That aborts a pending
	// ReceiveMessage long poll (which may otherwise block for WaitTimeSeconds)
	// instead of waiting it out.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go func() {
		select {
		case <-s.exit:
			cancel()
		case <-ctx.Done():
			// run returned on its own; nothing left to cancel.
		}
	}()

	// pause waits for d and reports whether polling should continue. It returns
	// false early if the subscriber is stopped while waiting.
	pause := func(d time.Duration) bool {
		select {
		case <-s.exit:
			return false
		case <-time.After(d):
			return true
		}
	}

	for {
		select {
		case <-s.exit:
			return
		default:
		}

		result, err := s.svc.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{
			QueueUrl:            &s.URL,
			MaxNumberOfMessages: s.getMaxMessages(),
			VisibilityTimeout:   s.getVisibilityTimeout(),
			WaitTimeSeconds:     s.getWaitSeconds(),
			AttributeNames: aws.StringSlice([]string{
				"SentTimestamp", // TODO: not currently exposing this to plugin users
			}),
			MessageAttributeNames: aws.StringSlice([]string{
				"All",
			}),
		})
		if err != nil {
			if ctx.Err() != nil {
				// The poll was aborted because the subscriber is shutting down;
				// this is not a receive failure worth reporting.
				return
			}
			// Report the failure right away, then back off before retrying.
			logger.Errorf("Error receiving SQS message: %s", err.Error())
			if !pause(time.Second) {
				return
			}
			continue
		}
		if len(result.Messages) == 0 {
			if !pause(time.Second) {
				return
			}
			continue
		}
		for _, sm := range result.Messages {
			s.handleMessage(sm, hdlr)
		}
	}
}
// getMaxMessages returns the MaxNumberOfMessages to request per ReceiveMessage call.
// This is the int64 stored under maxMessagesKey{} in the subscribe options, or
// defaultMaxMessages when none is set.
func (s *subscriber) getMaxMessages() *int64 {
	return s.int64Option(maxMessagesKey{}, defaultMaxMessages)
}

// getVisibilityTimeout returns the VisibilityTimeout to request per ReceiveMessage
// call. This is the int64 stored under visibilityTimeoutKey{}, or
// defaultVisibilityTimeout.
func (s *subscriber) getVisibilityTimeout() *int64 {
	return s.int64Option(visibilityTimeoutKey{}, defaultVisibilityTimeout)
}

// getWaitSeconds returns the WaitTimeSeconds to request per ReceiveMessage call.
// This is the int64 stored under waitTimeSecondsKey{}, or defaultWaitSeconds.
func (s *subscriber) getWaitSeconds() *int64 {
	return s.int64Option(waitTimeSecondsKey{}, defaultWaitSeconds)
}

// int64Option returns the value stored under key in the subscribe options' context as
// an *int64, or def when nothing is stored. Like a plain type assertion, it panics if
// the stored value is not an int64.
func (s *subscriber) int64Option(key interface{}, def int64) *int64 {
	stored := s.options.Context.Value(key)
	if stored == nil {
		return aws.Int64(def)
	}
	return aws.Int64(stored.(int64))
}
// handleMessage converts a raw SQS message into a broker event and passes it to hdlr.
//
// A handler error is recorded on the event and logged. With AutoAck enabled, the
// message is deleted from the queue only when the handler succeeded. A failed message
// is left in place so that SQS can deliver it again once its visibility timeout runs
// out; this matches the ack-on-success behaviour of go-micro's other brokers.
func (s *subscriber) handleMessage(msg *sqs.Message, hdlr broker.Handler) {
	// Body is a pointer in the SDK type; treat a missing body as empty rather than
	// dereferencing nil.
	var body string
	if msg.Body != nil {
		body = *msg.Body
	}
	logger.Debugf("Received SQS message: %d bytes", len(body))

	p := &sqsEvent{
		sMessage: msg,
		m: &broker.Message{
			Header: buildMessageHeader(msg.MessageAttributes),
			Body:   []byte(body),
		},
		URL:       s.URL,
		queueName: s.queueName,
		svc:       s.svc,
	}

	if p.err = hdlr(p); p.err != nil {
		logger.Errorf("Error handling SQS message from queue %s: %s", s.queueName, p.err.Error())
		// Do not auto-ack: deleting the message here would drop it for good.
		return
	}

	if s.options.AutoAck {
		if err := p.Ack(); err != nil {
			logger.Errorf("Failed auto-acknowledge of message: %s", err.Error())
		}
	}
}
// Options returns the subscribe options this subscription was created with.
func (s *subscriber) Options() broker.SubscribeOptions { return s.options }

// Topic returns the name of the SQS queue being polled.
func (s *subscriber) Topic() string { return s.queueName }

// Unsubscribe signals run to stop by closing s.exit, and always returns nil. If s.exit
// is already closed it does nothing.
// NOTE(review): two concurrent calls can both take the default branch and the second
// close panics; callers presumably call this once — confirm.
func (s *subscriber) Unsubscribe() error {
	select {
	case <-s.exit:
		// Already unsubscribed.
	default:
		close(s.exit)
	}
	return nil
}
// Error returns the error recorded from the handler call for this event, if any.
func (p *sqsEvent) Error() error { return p.err }

// Ack deletes the message from the queue it was received from, identifying it by its
// receipt handle, and returns any error from the DeleteMessage call.
func (p *sqsEvent) Ack() error {
	input := &sqs.DeleteMessageInput{
		QueueUrl:      &p.URL,
		ReceiptHandle: p.sMessage.ReceiptHandle,
	}
	if _, err := p.svc.DeleteMessage(input); err != nil {
		return err
	}
	return nil
}

// Topic returns the name of the queue the message was received from.
func (p *sqsEvent) Topic() string { return p.queueName }

// Message returns the broker message built from the SQS message.
func (p *sqsEvent) Message() *broker.Message { return p.m }
// Options returns the broker's current options.
func (b *awsServices) Options() broker.Options {
	return b.options
}

// Address always returns an empty string: the AWS SDK manages the service endpoints
// internally, so the broker has no server address of its own.
func (b *awsServices) Address() string {
	return ""
}
// Connect prepares the AWS clients the broker needs.
//
// It uses the *session.Session stored in the broker options (see getAwsClient) when
// one is present. Otherwise it creates a session from the shared AWS configuration.
// In both cases it then builds the SQS, SNS and STS clients from that session, applying
// any per-service *aws.Config found in the options. Finally it calls STS
// GetCallerIdentity to learn the account ID that Publish uses to build topic ARNs.
//
// Session-creation and STS failures are returned as wrapped errors instead of
// panicking.
func (b *awsServices) Connect() error {
	sess := b.getAwsClient()
	if sess == nil {
		var err error
		sess, err = session.NewSessionWithOptions(session.Options{
			SharedConfigState: session.SharedConfigEnable,
			Config:            aws.Config{},
		})
		if err != nil {
			return fmt.Errorf("unable to create AWS session: %w", err)
		}
	}
	// A caller-supplied session must still get service clients and an account ID;
	// returning early here would leave Publish/Subscribe with nil clients.
	b.sess = sess

	b.svcSqs = sqs.New(sess, b.getSQSConfig())
	b.svcSns = sns.New(sess, b.getSNSConfig())
	svcSts := sts.New(sess, b.getSTSConfig())

	result, err := svcSts.GetCallerIdentity(&sts.GetCallerIdentityInput{})
	if err != nil {
		return fmt.Errorf("unable to determine AWS AccountId: %w", err)
	}
	if result.Account == nil {
		return fmt.Errorf("unable to determine AWS AccountId: empty GetCallerIdentity response")
	}
	b.accountID = *result.Account
	return nil
}
// Disconnect is a no-op that always returns nil; the broker keeps no live connection
// that would need to be torn down.
func (b *awsServices) Disconnect() error { return nil }

// Init applies each option in opts, in order, to the broker's options. It always
// returns nil and does not create any AWS clients; that happens in Connect.
func (b *awsServices) Init(opts ...broker.Option) error {
	for i := range opts {
		apply := opts[i]
		apply(&b.options)
	}
	return nil
}
// Publish sends msg to the SNS topic named topic, in the account and region of the
// broker's session.
//
// Depending on the publish options, the body (ValidateBody) and/or headers
// (ValidateHeader) are validated first. Headers are sent as SNS string message
// attributes; when a header whitelist option is set, only whitelisted headers are sent.
//
// Publish returns an error, rather than panicking, when msg is nil, the broker has not
// been connected, or the session has no AWS region configured.
func (b *awsServices) Publish(topic string, msg *broker.Message, opts ...broker.PublishOption) error {
	if msg == nil {
		return fmt.Errorf("unable to publish to topic %s: nil message", topic)
	}

	options := broker.PublishOptions{}
	for _, o := range opts {
		o(&options)
	}

	if getValidateOnPublish(options.Context) {
		if err := ValidateBody(msg); err != nil {
			return err
		}
	}
	if getValidateHeaderOnPublish(options.Context) {
		if err := ValidateHeader(msg, getHeaderWhitelistOnPublish(options.Context)); err != nil {
			return err
		}
	}

	// Connect creates the session and SNS client; without them there is nothing to
	// publish with.
	if b.sess == nil || b.svcSns == nil {
		return fmt.Errorf("unable to publish to topic %s: broker is not connected", topic)
	}
	var region string
	if b.sess.Config != nil {
		region = aws.StringValue(b.sess.Config.Region)
	}
	if region == "" {
		return fmt.Errorf("unable to publish to topic %s: no AWS region configured", topic)
	}

	// The ARN is assembled locally rather than looked up, so the topic must live in the
	// session's account and region.
	// NOTE(review): the partition is fixed to "aws"; topics in other partitions
	// (e.g. aws-cn, aws-us-gov) cannot be addressed — confirm whether that matters.
	topicArn := arn.ARN{
		Partition: "aws",
		Service:   "sns",
		Region:    region,
		AccountID: b.accountID,
		Resource:  topic,
	}.String()

	input := &sns.PublishInput{
		Message:           aws.String(string(msg.Body)),
		TopicArn:          aws.String(topicArn),
		MessageAttributes: copyMessageHeader(options.Context, msg),
	}

	logger.Debugf("Publishing SNS message to %s, %d bytes", topic, len(msg.Body))
	if _, err := b.svcSns.Publish(input); err != nil {
		return err
	}
	// Broker interfaces don't let us do anything with message ID or sequence number
	return nil
}
// Subscribe resolves the URL of the SQS queue named queueName, then applies opts on top
// of the defaults (AutoAck on, Queue set to queueName, background context). It starts a
// goroutine (subscriber.run) that polls the queue and delivers each message to h. Call
// Unsubscribe on the returned Subscriber to stop polling.
func (b *awsServices) Subscribe(queueName string, h broker.Handler, opts ...broker.SubscribeOption) (broker.Subscriber, error) {
	url, err := b.urlFromQueueName(queueName)
	if err != nil {
		return nil, err
	}

	var subOpts broker.SubscribeOptions
	subOpts.AutoAck = true
	subOpts.Queue = queueName
	subOpts.Context = context.Background()
	for _, apply := range opts {
		apply(&subOpts)
	}

	sub := &subscriber{
		queueName: queueName,
		URL:       url,
		svc:       b.svcSqs,
		options:   subOpts,
		exit:      make(chan bool),
	}
	go sub.run(h)
	return sub, nil
}
// urlFromQueueName resolves the URL of the SQS queue named queueName with GetQueueUrl.
//
// SDK errors are wrapped with %w, so callers can still inspect them with errors.As and
// awserr.Error. A missing queue (sqs.ErrCodeQueueDoesNotExist) gets its own "unable to
// find queue" message. The function returns an error instead of panicking when the
// broker has no SQS client (Connect not called) or the response carries no URL.
func (b *awsServices) urlFromQueueName(queueName string) (string, error) {
	if b.svcSqs == nil {
		return "", fmt.Errorf("unable to determine URL for queue %s: broker is not connected", queueName)
	}
	resultURL, err := b.svcSqs.GetQueueUrl(&sqs.GetQueueUrlInput{
		QueueName: aws.String(queueName),
	})
	if err != nil {
		if aerr, ok := err.(awserr.Error); ok && aerr.Code() == sqs.ErrCodeQueueDoesNotExist {
			return "", fmt.Errorf("unable to find queue %s: %w", queueName, err)
		}
		return "", fmt.Errorf("unable to determine URL for queue %s: %w", queueName, err)
	}
	if resultURL.QueueUrl == nil {
		return "", fmt.Errorf("unable to determine URL for queue %s: empty GetQueueUrl response", queueName)
	}
	return *resultURL.QueueUrl, nil
}
// String returns the name of the broker plugin.
func (b *awsServices) String() string {
	return "snssqs"
}

// getAwsClient returns the *session.Session stored under sessClientKey{} in the broker
// options, or nil when none is stored.
func (b *awsServices) getAwsClient() *session.Session {
	if raw := b.options.Context.Value(sessClientKey{}); raw != nil {
		return raw.(*session.Session)
	}
	return nil
}

// getSNSConfig returns the *aws.Config stored under snsConfigKey{}, or nil.
func (b *awsServices) getSNSConfig() *aws.Config {
	return b.awsConfigOption(snsConfigKey{})
}

// getSQSConfig returns the *aws.Config stored under sqsConfigKey{}, or nil.
func (b *awsServices) getSQSConfig() *aws.Config {
	return b.awsConfigOption(sqsConfigKey{})
}

// getSTSConfig returns the *aws.Config stored under stsConfigKey{}, or nil.
func (b *awsServices) getSTSConfig() *aws.Config {
	return b.awsConfigOption(stsConfigKey{})
}

// awsConfigOption returns the *aws.Config stored under key in the broker options'
// context, or nil when nothing is stored. Like a plain type assertion, it panics if
// the stored value has a different type.
func (b *awsServices) awsConfigOption(key interface{}) *aws.Config {
	stored := b.options.Context.Value(key)
	if stored == nil {
		return nil
	}
	return stored.(*aws.Config)
}
// NewBroker returns an snssqs broker whose options start with a background context and
// then have opts applied in order. No AWS clients are created until Connect is called.
func NewBroker(opts ...broker.Option) broker.Broker {
	b := &awsServices{
		options: broker.Options{Context: context.Background()},
	}
	for _, apply := range opts {
		apply(&b.options)
	}
	return b
}
// copyMessageHeader converts the headers of m into SNS message attributes with data
// type "String". When ctx carries a header whitelist (see getHeaderWhitelistOnPublish),
// headers whose keys are not in it are skipped and logged at debug level.
func copyMessageHeader(ctx context.Context, m *broker.Message) (attribs map[string]*sns.MessageAttributeValue) {
	whitelist := getHeaderWhitelistOnPublish(ctx)
	attribs = make(map[string]*sns.MessageAttributeValue, len(m.Header))
	for key, value := range m.Header {
		// Looking up a key in a nil map is safe, so the lookup can precede the nil check.
		if _, allowed := whitelist[key]; whitelist != nil && !allowed {
			logger.Debugf("header not whitelisted, removing: %s", key)
			continue
		}
		attribs[key] = &sns.MessageAttributeValue{
			DataType:    aws.String("String"),
			StringValue: aws.String(value),
		}
	}
	return attribs
}
// buildMessageHeader converts SQS message attributes into a broker header map.
//
// The value is taken from StringValue when it is set. Otherwise it falls back to
// BinaryValue, copied in as raw bytes; per the SQS API, Binary attributes populate only
// BinaryValue. Attributes with neither value, or nil entries, are skipped. Previously,
// dereferencing the nil StringValue panicked and took down the polling goroutine.
func buildMessageHeader(attribs map[string]*sqs.MessageAttributeValue) map[string]string {
	res := make(map[string]string, len(attribs))
	for k, v := range attribs {
		switch {
		case v == nil:
			// Nothing to copy.
		case v.StringValue != nil:
			res[k] = *v.StringValue
		case v.BinaryValue != nil:
			res[k] = string(v.BinaryValue)
		}
	}
	return res
}
// ValidateBody checks msg's body against the strictest combined requirements of SNS
// and SQS. The body must be at most 256 KiB and must be valid UTF-8. It may also
// contain only characters SQS accepts (matched by validSqsRunes):
//
//	#x9 | #xA | #xD | #x20 to #xD7FF | #xE000 to #xFFFD | #x10000 to #x10FFFF
//
// The SQS character check runs concurrently over a few chunks of the body.
// Each goroutine writes only its own slot of results, and results is read only after
// wg.Wait(), so there is no shared mutable state. The previous version had every worker
// write one shared error variable while the sender read it: a data race. It also paid
// for one channel send per rune.
func ValidateBody(msg *broker.Message) error {
	// SNS requirements
	if len(msg.Body) > 256*1024 {
		return fmt.Errorf("message body over 256kB bytes")
	}
	if !utf8.Valid(msg.Body) {
		return fmt.Errorf("message body does not consist solely of UTF-8 characters")
	}

	// SQS requirements
	const numWorkers = 8
	chunks := splitAtRuneBoundaries(msg.Body, numWorkers)
	results := make([]error, len(chunks))
	var wg sync.WaitGroup
	for i := range chunks {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			results[i] = checkSqsRunes(chunks[i])
		}(i)
	}
	wg.Wait()

	for _, err := range results {
		if err != nil {
			return err
		}
	}
	return nil
}

// splitAtRuneBoundaries splits the valid UTF-8 slice b into at most n consecutive,
// non-empty sub-slices, never cutting through a multi-byte rune. n must be positive.
// An empty b yields no chunks.
func splitAtRuneBoundaries(b []byte, n int) [][]byte {
	if len(b) == 0 {
		return nil
	}
	size := (len(b) + n - 1) / n
	chunks := make([][]byte, 0, n)
	for len(b) > size {
		end := size
		// Move forward past continuation bytes so the chunk ends on a rune boundary.
		for end < len(b) && !utf8.RuneStart(b[end]) {
			end++
		}
		chunks = append(chunks, b[:end])
		b = b[end:]
	}
	if len(b) > 0 {
		chunks = append(chunks, b)
	}
	return chunks
}

// checkSqsRunes returns an error if b contains a rune outside validSqsRunes.
func checkSqsRunes(b []byte) error {
	for _, r := range string(b) {
		if !unicode.In(r, validSqsRunes) {
			return fmt.Errorf("message body contains invalid UTF-8 characters for SQS messages")
		}
	}
	return nil
}
// snsAttributeNameRE matches header keys that contain only characters SNS allows in
// message attribute names: A-Z, a-z, 0-9, '-', '_' and '.'. It is compiled once at
// package level instead of on every ValidateHeader call.
var snsAttributeNameRE = regexp.MustCompile(`(?i)^[A-Z0-9\-_\.]+$`)

// ValidateHeader checks the headers of msg that would be forwarded to SNS as message
// attributes. When whitelist is non-nil, only keys present in it are checked (the
// others are dropped by copyMessageHeader); otherwise every key is checked.
// It returns an error when more than 10 such headers exist, or when one of their keys
// contains a character that snsAttributeNameRE rejects.
//
// Earlier, a message with more than 10 headers and a whitelist of more than 10 keys
// was rejected even when at most 10 of its headers were whitelisted.
func ValidateHeader(msg *broker.Message, whitelist map[string]struct{}) error {
	forwarded := 0
	invalidKey, hasInvalid := "", false
	for k := range msg.Header {
		if whitelist != nil {
			if _, ok := whitelist[k]; !ok {
				continue
			}
		}
		forwarded++
		// hasInvalid is tracked separately because the empty key is itself invalid.
		if !hasInvalid && !snsAttributeNameRE.MatchString(k) {
			invalidKey, hasInvalid = k, true
		}
	}

	// SNS Requirement
	// can only have a max of 10 headers (converted to attributes) or silently fails
	if forwarded > 10 {
		return fmt.Errorf("too many headers %d (max 10)", forwarded)
	}

	// SNS Requirement
	// check for allowable characters in header name (A-Z, a-z, 0-9, -, _, .)
	if hasInvalid {
		return fmt.Errorf("invalid characters in header key %s", invalidKey)
	}
	return nil
}
func getValidateOnPublish(ctx context.Context) bool {
if ctx == nil {
return defaultValidateOnPublish
}
if v, ok := ctx.Value(validateOnPublishKey{}).(bool); ok && v {
return true
}
// false by default
return defaultValidateOnPublish
}
func getValidateHeaderOnPublish(ctx context.Context) bool {
if ctx == nil {
return defaultValidateHeaderOnPublish
}
if v, ok := ctx.Value(validateHeaderOnPublishKey{}).(bool); ok && v {
return true
}
// false by default
return defaultValidateHeaderOnPublish
}
func getHeaderWhitelistOnPublish(ctx context.Context) map[string]struct{} {
if ctx == nil {
return nil
}
if v, ok := ctx.Value(headerWhitelistOnPublishKey{}).(map[string]struct{}); ok {
return v
}
return nil
}