123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560 |
- package redriver
- import (
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "log"
- "strconv"
- "strings"
- "time"
- "github.com/aws/aws-sdk-go-v2/aws"
- "github.com/aws/aws-sdk-go-v2/service/sqs"
- "github.com/aws/aws-sdk-go-v2/service/sqs/types"
- "github.com/fgm/izidic"
- "code.osinet.fr/fgm/sqs_demo/back/services"
- )
- const (
- // BatchMax defines the maximum number of message usable in a batch operation, including redriving.
- BatchMax = 10
- // MessageSystemAttributeNameDeadLetterQueueSourceArn is an undocumented
- // types.Message attribute, used by the SQS console to support redrive.
- MessageSystemAttributeNameDeadLetterQueueSourceArn types.MessageSystemAttributeName = "DeadLetterQueueSourceArn"
- )
- var (
- ErrBatchTooBig = fmt.Errorf("operation requested on more than %d items", BatchMax)
- )
// ItemsKeys carries the identifiers of one message: its message ID and the
// receipt handle under which it was received. DeleteItems uses the former as
// the batch entry ID and the latter as the handle of the message to delete.
type ItemsKeys struct {
	MessageID, ReceiptHandle string
}
- type QueueRedrivePolicies struct {
- *QueueInfoAttributesRedrivePolicy
- *QueueInfoAttributesRedriveAllowPolicy
- }
- // Redriver is a redrive-oriented facade in front of the sqs.Client API.
- type Redriver interface {
- ListQueues(ctx context.Context, prefix string) (QueueUrls []string, err error)
- GetRedrivePolicies(ctx context.Context, qName string) (*QueueRedrivePolicies, error)
- GetQueueInfo(ctx context.Context, qName string) (*QueueInfo, error)
- GetQueueItems(ctx context.Context, qName string) ([]Message, error)
- DeleteItems(ctx context.Context, qName string, itemsIDs []ItemsKeys) error
- Purge(ctx context.Context, qName string) error
- RedriveItems(ctx context.Context, qName string, messages []Message) error
- }
- type redriver struct {
- VTO int32
- Wait int32
- io.Writer
- *sqs.Client
- }
- func (r *redriver) ListQueues(ctx context.Context, prefix string) (QueueUrls []string, err error) {
- // TODO implement pagination for the day we need more than 1000 queues to be reported.
- lqi := &sqs.ListQueuesInput{
- MaxResults: aws.Int32(1000),
- NextToken: nil,
- QueueNamePrefix: aws.String(prefix),
- }
- lqo, err := r.Client.ListQueues(ctx, lqi)
- if err != nil {
- return nil, fmt.Errorf("listing queues: %w", err)
- }
- return lqo.QueueUrls, nil
- }
- func (*redriver) parseQueueInfoRedrivePolicies(qName string, qao sqs.GetQueueAttributesOutput) (*QueueRedrivePolicies, error) {
- var qrp QueueRedrivePolicies
- srp := qao.Attributes[string(types.QueueAttributeNameRedrivePolicy)]
- if srp != "" {
- rp := QueueInfoAttributesRedrivePolicy{}
- if err := json.Unmarshal([]byte(srp), &rp); err != nil {
- return nil, fmt.Errorf(
- "failed parsing redrive policy for queue %q %w", qName, err)
- }
- if _, err := URLFromARNString(rp.DeadLetterTargetARN); err != nil {
- return nil, fmt.Errorf(
- "failed converting queue %q ARN to URL: %w", qName, err)
- }
- qrp.QueueInfoAttributesRedrivePolicy = &rp
- }
- srap := qao.Attributes[string(types.QueueAttributeNameRedriveAllowPolicy)]
- if srap != "" {
- rap := QueueInfoAttributesRedriveAllowPolicy{}
- if err := json.Unmarshal([]byte(srap), &rap); err != nil {
- return nil, fmt.Errorf(
- "failed parsing redrive allow policy for queue %q %w", qName, err)
- }
- for _, src := range rap.SourceQueueARNs {
- if _, err := URLFromARNString(src); err != nil {
- return nil, fmt.Errorf(
- "failed converting queue %q ARN to URL: %w", qName, err)
- }
- }
- qrp.QueueInfoAttributesRedriveAllowPolicy = &rap
- }
- return &qrp, nil
- }
- func (r *redriver) GetRedrivePolicies(ctx context.Context, qName string) (policies *QueueRedrivePolicies, err error) {
- qui := &sqs.GetQueueUrlInput{QueueName: &qName}
- qu, err := r.GetQueueUrl(ctx, qui)
- if err != nil {
- return nil, fmt.Errorf("failed getting URL for queue %q: %w", qName, err)
- }
- qai := &sqs.GetQueueAttributesInput{
- QueueUrl: qu.QueueUrl,
- AttributeNames: []types.QueueAttributeName{
- types.QueueAttributeNameRedrivePolicy,
- types.QueueAttributeNameRedriveAllowPolicy,
- },
- }
- qao, err := r.Client.GetQueueAttributes(ctx, qai)
- if err != nil {
- return nil, fmt.Errorf("failed getting DLQ policies for queue %q: %w", qName, err)
- }
- if qao == nil {
- return nil, fmt.Errorf("redrive policy info for queue %q is empty", qName)
- }
- qrp, err := r.parseQueueInfoRedrivePolicies(qName, *qao)
- if err != nil {
- return nil, fmt.Errorf("failed parsing redrive policy for queue %q: %w", qName, err)
- }
- if qrp == nil {
- return nil, nil // Queue has no DLQ: this is not an error
- }
- return qrp, nil
- }
- func (r *redriver) GetQueueInfo(ctx context.Context, qName string) (*QueueInfo, error) {
- qui := &sqs.GetQueueUrlInput{QueueName: &qName}
- qu, err := r.GetQueueUrl(ctx, qui)
- if err != nil {
- return nil, fmt.Errorf("failed getting URL for queue %q: %w", qName, err)
- }
- if qu.QueueUrl == nil {
- return nil, fmt.Errorf("URL for queue %q is empty", qName)
- }
- qai := &sqs.GetQueueAttributesInput{
- QueueUrl: qu.QueueUrl,
- AttributeNames: []types.QueueAttributeName{types.QueueAttributeNameAll},
- }
- qao, err := r.Client.GetQueueAttributes(ctx, qai)
- if err != nil {
- return nil, fmt.Errorf("failed getting all attributes for queue %q: %w", qName, err)
- }
- if qao == nil {
- return nil, fmt.Errorf("no attributes returned for queue %q", qName)
- }
- qi, err := r.parseQueueInfoAttributes(qName, qu, qao)
- if err != nil {
- return nil, err
- }
- return qi, err
- }
- func (r *redriver) parseQueueInfoAttributes(qName string, qu *sqs.GetQueueUrlOutput, qao *sqs.GetQueueAttributesOutput) (*QueueInfo, error) {
- var (
- qi *QueueInfo
- errCount int
- anom, anomd, anomnv int64
- created, changed int64 // timestamps
- delay, max, retention, vto, wait int64
- )
- u := *qu.QueueUrl
- a, err := ARNFromURL(u)
- if err != nil {
- return nil, fmt.Errorf("ARN for queue %q cannot be parsed from URL %q: %w", qName, u, err)
- }
- for _, field := range []struct {
- name types.QueueAttributeName
- value *int64
- sDef string
- def int64
- }{
- {types.QueueAttributeNameApproximateNumberOfMessages, &anom, "-1", -1},
- {types.QueueAttributeNameApproximateNumberOfMessagesDelayed, &anomd, "-1", -1},
- {types.QueueAttributeNameApproximateNumberOfMessagesNotVisible, &anomnv, "-1", -1},
- {types.QueueAttributeNameCreatedTimestamp, &created, "0", 0},
- {types.QueueAttributeNameDelaySeconds, &delay, "-1", -1},
- {types.QueueAttributeNameLastModifiedTimestamp, &changed, "0", 0},
- {types.QueueAttributeNameMaximumMessageSize, &max, "262144", 1 << (8 + 10)}, // 256ko
- {types.QueueAttributeNameMessageRetentionPeriod, &max, "1209600", 14 * 24 * 60 * 60}, // 2 weeks
- {types.QueueAttributeNameReceiveMessageWaitTimeSeconds, &wait, "0", 0}, // short polling
- {types.QueueAttributeNameVisibilityTimeout, &vto, "0", 0}, // short polling
- } {
- s, ok := qao.Attributes[string(field.name)]
- if !ok {
- errCount++
- s = field.sDef
- }
- n, err := strconv.Atoi(s)
- if err != nil {
- errCount++
- }
- *field.value = int64(n)
- }
- qrp, err := r.parseQueueInfoRedrivePolicies(qName, *qao)
- if err != nil {
- return nil, fmt.Errorf("failed parsing redrive policy for queue %q: %w", qName, err)
- }
- qi = &QueueInfo{
- Name: qName,
- URL: u,
- Attributes: &QueueInfoAttributes{
- ApproximateNumberOfMessages: anom,
- ApproximateNumberOfMessagesDelayed: anomd,
- ApproximateNumberOfMessagesNotVisible: anomnv,
- CreatedTimestamp: created,
- DelaySeconds: delay,
- LastModifiedTimestamp: changed,
- MaximumMessageSize: max,
- MessageRetentionPeriod: retention,
- QueueARN: a.String(),
- ReceiveMessageWaitTimeSeconds: wait,
- VisibilityTimeout: vto,
- RedrivePolicy: qrp.QueueInfoAttributesRedrivePolicy,
- RedriveAllowPolicy: qrp.QueueInfoAttributesRedriveAllowPolicy,
- },
- }
- return qi, nil
- }
- func (*redriver) parseReceivedAttributes(msg types.Message) (*SystemMessageAttributes, error) {
- var (
- afrt, arc, sent int64
- errCount int
- err error
- )
- for _, field := range []struct {
- name types.MessageSystemAttributeName
- value *int64
- sDef string
- def int64
- }{
- {types.MessageSystemAttributeNameApproximateFirstReceiveTimestamp, &afrt, "0", 0},
- {types.MessageSystemAttributeNameApproximateReceiveCount, &arc, "0", 0},
- {types.MessageSystemAttributeNameSentTimestamp, &sent, "0", 0},
- } {
- s, ok := msg.Attributes[string(field.name)]
- if !ok {
- errCount++
- s = field.sDef
- }
- n, err := strconv.Atoi(s)
- if err != nil {
- errCount++
- }
- *field.value = int64(n)
- }
- ma := SystemMessageAttributes{
- ApproximateFirstReceiveTimestamp: afrt,
- ApproximateReceiveCount: arc,
- DeadLetterQueueSourceARN: msg.Attributes[string(MessageSystemAttributeNameDeadLetterQueueSourceArn)],
- SenderId: msg.Attributes[string(types.MessageSystemAttributeNameSenderId)],
- SentTimestamp: sent,
- }
- if errCount > 0 {
- err = fmt.Errorf("encounted %d errors parsing system message attributes", errCount)
- return nil, err
- }
- return &ma, nil
- }
- func (r *redriver) GetQueueItems(ctx context.Context, qName string) ([]Message, error) {
- qui := &sqs.GetQueueUrlInput{QueueName: &qName}
- qu, err := r.GetQueueUrl(ctx, qui)
- if err != nil {
- return nil, fmt.Errorf("failed getting URL for queue %q: %w", qName, err)
- }
- if qu.QueueUrl == nil {
- return nil, fmt.Errorf("URL for queue %q is empty", qName)
- }
- rmi := sqs.ReceiveMessageInput{
- QueueUrl: qu.QueueUrl,
- AttributeNames: []types.QueueAttributeName{types.QueueAttributeNameAll},
- MaxNumberOfMessages: 10,
- MessageAttributeNames: []string{".*"},
- VisibilityTimeout: 0,
- WaitTimeSeconds: r.Wait,
- }
- t0 := time.Now()
- rmo, err := r.Client.ReceiveMessage(ctx, &rmi)
- d := time.Since(t0)
- if err != nil {
- return nil, fmt.Errorf("after %v, failed receiving messages for queue %q: %w",
- d, qName, err)
- }
- ms := make([]Message, 0, len(rmo.Messages))
- for _, m := range rmo.Messages {
- ma, err := r.parseReceivedAttributes(m)
- if err != nil {
- return nil, err
- }
- for _, field := range []struct {
- name string
- value any
- }{
- {"body", m.Body},
- {"md5 of body", m.MD5OfBody},
- {"md5 of message attributes", m.MD5OfMessageAttributes},
- {"message ID", m.MessageId},
- {"receipt handle", m.ReceiptHandle},
- } {
- if field.value == nil {
- return nil, fmt.Errorf("missing field %s on message", field.name)
- }
- }
- j, err := JSONableFromMessageAttributeValues(m.MessageAttributes)
- if err != nil {
- return nil, fmt.Errorf("failed decoding message boby: %w", err)
- }
- m2 := Message{
- Attributes: ma,
- MessageAttributes: j,
- }
- for _, pair := range []struct{ dest, src *string }{
- {&m2.Body, m.Body},
- {&m2.Md5OfBody, m.MD5OfBody},
- {&m2.Md5OfMessageAttributes, m.MD5OfMessageAttributes},
- {&m2.MessageId, m.MessageId},
- {&m2.ReceiptHandle, m.ReceiptHandle},
- } {
- if pair.src != nil {
- *pair.dest = *pair.src
- }
- }
- ms = append(ms, m2)
- }
- return ms, err
- }
- func (r *redriver) DeleteItems(ctx context.Context, qName string, keys []ItemsKeys) error {
- qui := &sqs.GetQueueUrlInput{QueueName: &qName}
- qu, err := r.GetQueueUrl(ctx, qui)
- if err != nil {
- return fmt.Errorf("failed getting URL for queue %q: %w", qName, err)
- }
- entries := make([]types.DeleteMessageBatchRequestEntry, len(keys))
- for i := 0; i < len(keys); i++ {
- entries[i] = types.DeleteMessageBatchRequestEntry{
- Id: aws.String(keys[i].MessageID),
- ReceiptHandle: aws.String(keys[i].ReceiptHandle),
- }
- }
- dmi := sqs.DeleteMessageBatchInput{
- Entries: entries,
- QueueUrl: qu.QueueUrl,
- }
- dmo, err := r.DeleteMessageBatch(ctx, &dmi)
- if err != nil {
- return fmt.Errorf("failed deleting %d items from queue %q: %v",
- len(keys), qName, err)
- }
- if len(dmo.Failed) > 0 {
- errs := make([]string, len(dmo.Failed))
- for i, bree := range dmo.Failed {
- source := "aws"
- if bree.SenderFault {
- source = "redriver"
- }
- errs[i] = fmt.Sprintf("ID: %s / Failure: %s = %q / Source: %s",
- *bree.Id, *bree.Message, *bree.Code, source)
- }
- return fmt.Errorf("failed deleting %d items out of %d from queue %q: %s",
- len(dmo.Failed), len(keys), qName, strings.Join(errs, "\n"))
- }
- log.Println(dmo)
- return nil
- }
- func (r *redriver) Purge(ctx context.Context, qName string) error {
- qui := &sqs.GetQueueUrlInput{QueueName: &qName}
- qu, err := r.GetQueueUrl(ctx, qui)
- if err != nil {
- return fmt.Errorf("failed getting URL for queue %q: %w", qName, err)
- }
- pqi := sqs.PurgeQueueInput{QueueUrl: qu.QueueUrl}
- _, err = r.PurgeQueue(ctx, &pqi)
- if err != nil {
- return fmt.Errorf("failed purging queue %q: %w",
- qName, err)
- }
- return nil
- }
- // RedriveItems redrives the selected message back to their respective source queue,
- // removing them from the DLQ once they have been sent.
- //
- // Since a queue can act as a DLQ for more than one source queue, the messages
- // sends are grouped by source queue.
- func (r *redriver) RedriveItems(ctx context.Context, dlqName string, messages []Message) error {
- sqURLs := make(map[string][]Message, 1) // In most cases, only a single queue will be used.
- for _, message := range messages {
- sARN := message.Attributes.DeadLetterQueueSourceARN
- sURL, err := URLFromARNString(sARN)
- if err != nil {
- return fmt.Errorf("failed resolving source ARN %q to URL: %v", sARN, err)
- }
- sqURLs[sURL] = append(sqURLs[sURL], message)
- }
- for qURL, messages := range sqURLs {
- if err := r.redriveQueueMessages(ctx, dlqName, qURL, messages); err != nil {
- return err
- }
- }
- return nil
- }
- // redriveQueueMessages handles message redriving for messages in a single queue.
- func (r *redriver) redriveQueueMessages(ctx context.Context, dlqName string, qURL string, messages []Message) error {
- if len(messages) > BatchMax {
- return ErrBatchTooBig
- }
- qui := &sqs.GetQueueUrlInput{QueueName: &dlqName}
- dlqURL, err := r.GetQueueUrl(ctx, qui)
- if err != nil || dlqURL == nil {
- return fmt.Errorf("failed getting URL for queue %q: %w", dlqName, err)
- }
- // Hide messages to prevent other consumers from seeing them and generating duplicates.
- fatal, nontafal := r.hideQueueMessages(ctx, "", *dlqURL.QueueUrl, messages)
- if fatal != nil {
- return fmt.Errorf("failed hiding messages during redrive towards queue %q: %w", qURL, fatal)
- }
- if nontafal != nil {
- log.Printf("Redrive nonfatal error hiding messages on queue %q: %v", dlqName, nontafal)
- }
- // Send the messages back to their source queue.
- if err := r.resendQueueMessages(ctx, qURL, messages); err != nil {
- return fmt.Errorf("failed sending messages back to queue %q: %w", qURL, err)
- }
- // Delete them from the DLQ.
- keys := make([]ItemsKeys, len(messages))
- for i, m := range messages {
- keys[i] = m.Keys()
- }
- if err := r.DeleteItems(ctx, dlqName, keys); err != nil {
- return fmt.Errorf("failed deleting messages already redriven from DLQ %s to queue %q: beware of duplicates: %w",
- dlqName, qURL, err)
- }
- return nil
- }
- func (r *redriver) resendQueueMessages(ctx context.Context, qURL string, messages []Message) error {
- smbre := make([]types.SendMessageBatchRequestEntry, len(messages))
- for i, m := range messages {
- m.MessageAttributes["previous-message-id"] = m.MessageId
- mav, err := MessageAttributeValuesFromJSONable(m.MessageAttributes)
- if err != nil {
- return fmt.Errorf("failed converting message attributes for message %s on queue %q: %v",
- m.MessageId, qURL, err)
- }
- smbre[i] = types.SendMessageBatchRequestEntry{
- Id: aws.String(strconv.Itoa(i)),
- MessageBody: &m.Body,
- MessageAttributes: mav,
- }
- }
- smbi := sqs.SendMessageBatchInput{
- Entries: smbre,
- QueueUrl: &qURL,
- }
- smbo, err := r.SendMessageBatch(ctx, &smbi)
- if err != nil {
- return fmt.Errorf("failed sending messages to queue %q: %v",
- qURL, err)
- }
- if len(smbo.Failed) == 0 {
- return nil
- }
- errs := make([]error, len(smbo.Failed))
- for _, err := range smbo.Failed {
- msg := fmt.Sprintf("ID: %s, Code: %s, Message: %s", *err.Id, *err.Code, *err.Message)
- if err.SenderFault {
- msg += " (sender fault)"
- }
- errs = append(errs, errors.New(msg))
- }
- return fmt.Errorf("partial redrive: failed re-sending %d/%d messages, %v",
- len(smbo.Failed), len(smbi.Entries), errs)
- }
- func (r *redriver) hideQueueMessages(ctx context.Context, dlqName string, qURL string, messages []Message) (fatal, nonfatal error) {
- cmvbre := make([]types.ChangeMessageVisibilityBatchRequestEntry, len(messages))
- for i, m := range messages {
- cmvbre[i] = types.ChangeMessageVisibilityBatchRequestEntry{
- Id: aws.String(strconv.Itoa(i)),
- ReceiptHandle: aws.String(m.ReceiptHandle),
- VisibilityTimeout: r.VTO,
- }
- }
- cmvbi := sqs.ChangeMessageVisibilityBatchInput{
- Entries: cmvbre,
- QueueUrl: aws.String(qURL),
- }
- cmvbo, err := r.ChangeMessageVisibilityBatch(ctx, &cmvbi)
- if err != nil {
- return fmt.Errorf("failed hiding request on DLQ %q: %w", dlqName, err), nil
- }
- switch len(cmvbo.Failed) {
- case len(cmvbi.Entries):
- // No message made it: abort.
- return fmt.Errorf("failed hiding all %d messages on DLQ %q", len(cmvbi.Entries), dlqName), nil
- case 0:
- return nil, nil // All well
- default:
- errs := make([]error, len(cmvbo.Failed))
- for _, err := range cmvbo.Failed {
- msg := fmt.Sprintf("ID: %s, Code: %s, Message: %s", *err.Id, *err.Code, *err.Message)
- if err.SenderFault {
- msg += " (sender fault)"
- }
- errs = append(errs, errors.New(msg))
- }
- // Some message made it, crossing fingers.
- return nil, fmt.Errorf("failed hiding %d/%d messages, %v",
- len(cmvbo.Failed), len(cmvbi.Entries), errs)
- }
- }
- func RedriverService(dic *izidic.Container) (any, error) {
- cli := dic.MustService(services.SvcClient).(*sqs.Client)
- w := dic.MustParam(services.PWriter).(io.Writer)
- vto := dic.MustParam(services.PVTO).(time.Duration)
- wait := int32(dic.MustParam(services.PWait).(int))
- return &redriver{
- Client: cli,
- VTO: int32(vto.Seconds()),
- Wait: wait,
- Writer: w,
- }, nil
- }
|