package redriver

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"strconv"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/sqs"
	"github.com/aws/aws-sdk-go-v2/service/sqs/types"
	"github.com/fgm/izidic"

	"code.osinet.fr/fgm/sqs_demo/back/services"
)

const (
	// BatchMax defines the maximum number of messages usable in a batch
	// operation, including redriving.
	BatchMax = 10

	// MessageSystemAttributeNameDeadLetterQueueSourceArn is an undocumented
	// types.Message attribute, used by the SQS console to support redrive.
	MessageSystemAttributeNameDeadLetterQueueSourceArn types.MessageSystemAttributeName = "DeadLetterQueueSourceArn"
)

var (
	// ErrBatchTooBig is returned when a batch operation is requested on more
	// than BatchMax items.
	ErrBatchTooBig = fmt.Errorf("operation requested on more than %d items", BatchMax)
)

// ItemsKeys holds the identifiers needed to delete a received message:
// its message ID and its receipt handle.
type ItemsKeys struct {
	MessageID     string
	ReceiptHandle string
}

// QueueRedrivePolicies groups the redrive policy and the redrive allow policy
// of a queue. Either embedded pointer is nil when the queue does not carry the
// matching attribute.
type QueueRedrivePolicies struct {
	*QueueInfoAttributesRedrivePolicy
	*QueueInfoAttributesRedriveAllowPolicy
}

// Redriver is a redrive-oriented facade in front of the sqs.Client API.
type Redriver interface {
	ListQueues(ctx context.Context, prefix string) (QueueUrls []string, err error)
	GetRedrivePolicies(ctx context.Context, qName string) (*QueueRedrivePolicies, error)
	GetQueueInfo(ctx context.Context, qName string) (*QueueInfo, error)
	GetQueueItems(ctx context.Context, qName string) ([]Message, error)
	DeleteItems(ctx context.Context, qName string, itemsIDs []ItemsKeys) error
	Purge(ctx context.Context, qName string) error
	RedriveItems(ctx context.Context, qName string, messages []Message) error
}

// redriver is the Redriver implementation built on an embedded *sqs.Client.
type redriver struct {
	// VTO is the visibility timeout, in seconds, applied to messages hidden
	// during a redrive.
	VTO int32
	// Wait is the WaitTimeSeconds value used when receiving messages.
	Wait int32
	io.Writer
	*sqs.Client
}

// nonNilMap returns m, or a fresh empty map of the same type when m is nil,
// so that callers can always assign into the result.
func nonNilMap[M ~map[K]V, K comparable, V any](m M) M {
	if m == nil {
		return make(M)
	}
	return m
}

// batchErrors converts the failed entries of a batch operation to errors,
// tolerating entries on which optional fields are not set.
func batchErrors(failed []types.BatchResultErrorEntry) []error {
	errs := make([]error, 0, len(failed))
	for _, entry := range failed {
		msg := fmt.Sprintf("ID: %s, Code: %s, Message: %s",
			aws.ToString(entry.Id), aws.ToString(entry.Code), aws.ToString(entry.Message))
		if entry.SenderFault {
			msg += " (sender fault)"
		}
		errs = append(errs, errors.New(msg))
	}
	return errs
}

// ListQueues returns the URLs of all queues whose name starts with prefix,
// following the pagination token until the listing is exhausted.
func (r *redriver) ListQueues(ctx context.Context, prefix string) (QueueUrls []string, err error) {
	lqi := &sqs.ListQueuesInput{
		MaxResults:      aws.Int32(1000),
		QueueNamePrefix: aws.String(prefix),
	}
	var urls []string
	for {
		lqo, err := r.Client.ListQueues(ctx, lqi)
		if err != nil {
			return nil, fmt.Errorf("listing queues: %w", err)
		}
		urls = append(urls, lqo.QueueUrls...)
		if aws.ToString(lqo.NextToken) == "" {
			return urls, nil
		}
		lqi.NextToken = lqo.NextToken
	}
}

// parseQueueInfoRedrivePolicies decodes the RedrivePolicy and
// RedriveAllowPolicy attributes found in qao, if any, and verifies that every
// ARN they contain can be converted to a queue URL.
//
// The result is never nil when the error is nil; its embedded pointers are nil
// for the attributes which are absent or empty.
func (*redriver) parseQueueInfoRedrivePolicies(qName string, qao sqs.GetQueueAttributesOutput) (*QueueRedrivePolicies, error) {
	var qrp QueueRedrivePolicies

	if srp := qao.Attributes[string(types.QueueAttributeNameRedrivePolicy)]; srp != "" {
		rp := QueueInfoAttributesRedrivePolicy{}
		if err := json.Unmarshal([]byte(srp), &rp); err != nil {
			return nil, fmt.Errorf("failed parsing redrive policy for queue %q: %w", qName, err)
		}
		if _, err := URLFromARNString(rp.DeadLetterTargetARN); err != nil {
			return nil, fmt.Errorf("failed converting queue %q ARN to URL: %w", qName, err)
		}
		qrp.QueueInfoAttributesRedrivePolicy = &rp
	}

	if srap := qao.Attributes[string(types.QueueAttributeNameRedriveAllowPolicy)]; srap != "" {
		rap := QueueInfoAttributesRedriveAllowPolicy{}
		if err := json.Unmarshal([]byte(srap), &rap); err != nil {
			return nil, fmt.Errorf("failed parsing redrive allow policy for queue %q: %w", qName, err)
		}
		for _, src := range rap.SourceQueueARNs {
			if _, err := URLFromARNString(src); err != nil {
				return nil, fmt.Errorf("failed converting queue %q ARN to URL: %w", qName, err)
			}
		}
		qrp.QueueInfoAttributesRedriveAllowPolicy = &rap
	}

	return &qrp, nil
}

// GetRedrivePolicies fetches and parses the redrive policies of queue qName.
func (r *redriver) GetRedrivePolicies(ctx context.Context, qName string) (policies *QueueRedrivePolicies, err error) {
	qui := &sqs.GetQueueUrlInput{QueueName: &qName}
	qu, err := r.GetQueueUrl(ctx, qui)
	if err != nil {
		return nil, fmt.Errorf("failed getting URL for queue %q: %w", qName, err)
	}
	if qu == nil || qu.QueueUrl == nil {
		return nil, fmt.Errorf("URL for queue %q is empty", qName)
	}

	qai := &sqs.GetQueueAttributesInput{
		QueueUrl: qu.QueueUrl,
		AttributeNames: []types.QueueAttributeName{
			types.QueueAttributeNameRedrivePolicy,
			types.QueueAttributeNameRedriveAllowPolicy,
		},
	}
	qao, err := r.Client.GetQueueAttributes(ctx, qai)
	if err != nil {
		return nil, fmt.Errorf("failed getting DLQ policies for queue %q: %w", qName, err)
	}
	if qao == nil {
		return nil, fmt.Errorf("redrive policy info for queue %q is empty", qName)
	}

	qrp, err := r.parseQueueInfoRedrivePolicies(qName, *qao)
	if err != nil {
		return nil, fmt.Errorf("failed parsing redrive policy for queue %q: %w", qName, err)
	}
	if qrp == nil {
		return nil, nil // Queue has no DLQ: this is not an error
	}
	return qrp, nil
}

// GetQueueInfo fetches all the attributes of queue qName and returns them in
// parsed form, along with the queue name and URL.
func (r *redriver) GetQueueInfo(ctx context.Context, qName string) (*QueueInfo, error) {
	qui := &sqs.GetQueueUrlInput{QueueName: &qName}
	qu, err := r.GetQueueUrl(ctx, qui)
	if err != nil {
		return nil, fmt.Errorf("failed getting URL for queue %q: %w", qName, err)
	}
	if qu == nil || qu.QueueUrl == nil {
		return nil, fmt.Errorf("URL for queue %q is empty", qName)
	}

	qai := &sqs.GetQueueAttributesInput{
		QueueUrl:       qu.QueueUrl,
		AttributeNames: []types.QueueAttributeName{types.QueueAttributeNameAll},
	}
	qao, err := r.Client.GetQueueAttributes(ctx, qai)
	if err != nil {
		return nil, fmt.Errorf("failed getting all attributes for queue %q: %w", qName, err)
	}
	if qao == nil {
		return nil, fmt.Errorf("no attributes returned for queue %q", qName)
	}

	return r.parseQueueInfoAttributes(qName, qu, qao)
}

// parseQueueInfoAttributes builds a QueueInfo from the queue URL in qu and
// the raw string attributes in qao.
//
// Numeric attributes which are missing or cannot be parsed as a base-10 int64
// take the default listed for them below.
func (r *redriver) parseQueueInfoAttributes(qName string, qu *sqs.GetQueueUrlOutput, qao *sqs.GetQueueAttributesOutput) (*QueueInfo, error) {
	var (
		anom, anomd, anomnv                  int64
		created, changed                     int64 // timestamps
		delay, maxSize, retention, vto, wait int64
	)

	if qu == nil || qu.QueueUrl == nil {
		return nil, fmt.Errorf("URL for queue %q is empty", qName)
	}
	if qao == nil {
		return nil, fmt.Errorf("no attributes returned for queue %q", qName)
	}

	u := *qu.QueueUrl
	a, err := ARNFromURL(u)
	if err != nil {
		return nil, fmt.Errorf("ARN for queue %q cannot be parsed from URL %q: %w", qName, u, err)
	}

	for _, field := range []struct {
		name  types.QueueAttributeName
		value *int64
		def   int64
	}{
		{types.QueueAttributeNameApproximateNumberOfMessages, &anom, -1},
		{types.QueueAttributeNameApproximateNumberOfMessagesDelayed, &anomd, -1},
		{types.QueueAttributeNameApproximateNumberOfMessagesNotVisible, &anomnv, -1},
		{types.QueueAttributeNameCreatedTimestamp, &created, 0},
		{types.QueueAttributeNameDelaySeconds, &delay, -1},
		{types.QueueAttributeNameLastModifiedTimestamp, &changed, 0},
		{types.QueueAttributeNameMaximumMessageSize, &maxSize, 1 << (8 + 10)},        // 256 KiB
		{types.QueueAttributeNameMessageRetentionPeriod, &retention, 14 * 24 * 60 * 60}, // 2 weeks
		{types.QueueAttributeNameReceiveMessageWaitTimeSeconds, &wait, 0},            // short polling
		{types.QueueAttributeNameVisibilityTimeout, &vto, 0},
	} {
		*field.value = field.def
		s, ok := qao.Attributes[string(field.name)]
		if !ok {
			continue
		}
		// ParseInt with bitSize 64 keeps millisecond timestamps intact even
		// where int is 32 bits wide.
		if n, err := strconv.ParseInt(s, 10, 64); err == nil {
			*field.value = n
		}
	}

	qrp, err := r.parseQueueInfoRedrivePolicies(qName, *qao)
	if err != nil {
		return nil, fmt.Errorf("failed parsing redrive policy for queue %q: %w", qName, err)
	}

	qi := &QueueInfo{
		Name: qName,
		URL:  u,
		Attributes: &QueueInfoAttributes{
			ApproximateNumberOfMessages:           anom,
			ApproximateNumberOfMessagesDelayed:    anomd,
			ApproximateNumberOfMessagesNotVisible: anomnv,
			CreatedTimestamp:                      created,
			DelaySeconds:                          delay,
			LastModifiedTimestamp:                 changed,
			MaximumMessageSize:                    maxSize,
			MessageRetentionPeriod:                retention,
			QueueARN:                              a.String(),
			ReceiveMessageWaitTimeSeconds:         wait,
			VisibilityTimeout:                     vto,
			RedrivePolicy:                         qrp.QueueInfoAttributesRedrivePolicy,
			RedriveAllowPolicy:                    qrp.QueueInfoAttributesRedriveAllowPolicy,
		},
	}
	return qi, nil
}

// parseReceivedAttributes extracts the system attributes of a received
// message.
//
// It returns an error if any of the numeric attributes (first receive
// timestamp, receive count, sent timestamp) is missing or is not a valid
// base-10 int64. The DLQ source ARN and sender ID are copied as-is and are
// empty when absent.
func (*redriver) parseReceivedAttributes(msg types.Message) (*SystemMessageAttributes, error) {
	var (
		afrt, arc, sent int64
		invalid         []string // names of the attributes which could not be used
	)

	for _, field := range []struct {
		name  types.MessageSystemAttributeName
		value *int64
	}{
		{types.MessageSystemAttributeNameApproximateFirstReceiveTimestamp, &afrt},
		{types.MessageSystemAttributeNameApproximateReceiveCount, &arc},
		{types.MessageSystemAttributeNameSentTimestamp, &sent},
	} {
		s, ok := msg.Attributes[string(field.name)]
		if !ok {
			invalid = append(invalid, string(field.name)+" (missing)")
			continue
		}
		n, err := strconv.ParseInt(s, 10, 64)
		if err != nil {
			invalid = append(invalid, string(field.name)+" (not an integer)")
			continue
		}
		*field.value = n
	}
	if len(invalid) > 0 {
		return nil, fmt.Errorf("encountered %d errors parsing system message attributes: %s",
			len(invalid), strings.Join(invalid, ", "))
	}

	return &SystemMessageAttributes{
		ApproximateFirstReceiveTimestamp: afrt,
		ApproximateReceiveCount:          arc,
		DeadLetterQueueSourceARN:         msg.Attributes[string(MessageSystemAttributeNameDeadLetterQueueSourceArn)],
		SenderId:                         msg.Attributes[string(types.MessageSystemAttributeNameSenderId)],
		SentTimestamp:                    sent,
	}, nil
}

// GetQueueItems receives up to 10 messages from queue qName, requesting a
// visibility timeout of 0, and converts them to the package Message type.
func (r *redriver) GetQueueItems(ctx context.Context, qName string) ([]Message, error) {
	qui := &sqs.GetQueueUrlInput{QueueName: &qName}
	qu, err := r.GetQueueUrl(ctx, qui)
	if err != nil {
		return nil, fmt.Errorf("failed getting URL for queue %q: %w", qName, err)
	}
	if qu == nil || qu.QueueUrl == nil {
		return nil, fmt.Errorf("URL for queue %q is empty", qName)
	}

	rmi := sqs.ReceiveMessageInput{
		QueueUrl:              qu.QueueUrl,
		AttributeNames:        []types.QueueAttributeName{types.QueueAttributeNameAll},
		MaxNumberOfMessages:   10,
		MessageAttributeNames: []string{".*"},
		VisibilityTimeout:     0,
		WaitTimeSeconds:       r.Wait,
	}
	t0 := time.Now()
	rmo, err := r.Client.ReceiveMessage(ctx, &rmi)
	d := time.Since(t0)
	if err != nil {
		return nil, fmt.Errorf("after %v, failed receiving messages for queue %q: %w", d, qName, err)
	}

	ms := make([]Message, 0, len(rmo.Messages))
	for _, m := range rmo.Messages {
		ma, err := r.parseReceivedAttributes(m)
		if err != nil {
			return nil, err
		}

		// The values are kept as *string: boxing them in an interface would
		// make nil pointers compare as non-nil and defeat the check.
		// The MD5 of message attributes is not listed because it is absent on
		// messages sent without attributes.
		for _, field := range []struct {
			name  string
			value *string
		}{
			{"body", m.Body},
			{"md5 of body", m.MD5OfBody},
			{"message ID", m.MessageId},
			{"receipt handle", m.ReceiptHandle},
		} {
			if field.value == nil {
				return nil, fmt.Errorf("missing field %s on message", field.name)
			}
		}

		j, err := JSONableFromMessageAttributeValues(m.MessageAttributes)
		if err != nil {
			return nil, fmt.Errorf("failed decoding message attributes: %w", err)
		}

		ms = append(ms, Message{
			Attributes:             ma,
			MessageAttributes:      j,
			Body:                   aws.ToString(m.Body),
			Md5OfBody:              aws.ToString(m.MD5OfBody),
			Md5OfMessageAttributes: aws.ToString(m.MD5OfMessageAttributes),
			MessageId:              aws.ToString(m.MessageId),
			ReceiptHandle:          aws.ToString(m.ReceiptHandle),
		})
	}
	return ms, nil
}

// DeleteItems deletes the messages identified by keys from queue qName in a
// single batch request.
//
// It is a no-op for an empty keys slice, and returns ErrBatchTooBig when more
// than BatchMax keys are passed. Partial failures are reported as one error
// listing each failed entry.
func (r *redriver) DeleteItems(ctx context.Context, qName string, keys []ItemsKeys) error {
	if len(keys) == 0 {
		return nil
	}
	if len(keys) > BatchMax {
		return ErrBatchTooBig
	}

	qui := &sqs.GetQueueUrlInput{QueueName: &qName}
	qu, err := r.GetQueueUrl(ctx, qui)
	if err != nil {
		return fmt.Errorf("failed getting URL for queue %q: %w", qName, err)
	}
	if qu == nil || qu.QueueUrl == nil {
		return fmt.Errorf("URL for queue %q is empty", qName)
	}

	entries := make([]types.DeleteMessageBatchRequestEntry, len(keys))
	for i, key := range keys {
		entries[i] = types.DeleteMessageBatchRequestEntry{
			Id:            aws.String(key.MessageID),
			ReceiptHandle: aws.String(key.ReceiptHandle),
		}
	}
	dmi := sqs.DeleteMessageBatchInput{
		Entries:  entries,
		QueueUrl: qu.QueueUrl,
	}
	dmo, err := r.DeleteMessageBatch(ctx, &dmi)
	if err != nil {
		return fmt.Errorf("failed deleting %d items from queue %q: %w", len(keys), qName, err)
	}

	if len(dmo.Failed) > 0 {
		errs := make([]string, len(dmo.Failed))
		for i, bree := range dmo.Failed {
			source := "aws"
			if bree.SenderFault {
				source = "redriver"
			}
			errs[i] = fmt.Sprintf("ID: %s / Failure: %s = %q / Source: %s",
				aws.ToString(bree.Id), aws.ToString(bree.Code), aws.ToString(bree.Message), source)
		}
		return fmt.Errorf("failed deleting %d items out of %d from queue %q: %s",
			len(dmo.Failed), len(keys), qName, strings.Join(errs, "\n"))
	}

	log.Println(dmo)
	return nil
}

// Purge removes all messages from queue qName.
func (r *redriver) Purge(ctx context.Context, qName string) error {
	qui := &sqs.GetQueueUrlInput{QueueName: &qName}
	qu, err := r.GetQueueUrl(ctx, qui)
	if err != nil {
		return fmt.Errorf("failed getting URL for queue %q: %w", qName, err)
	}
	if qu == nil || qu.QueueUrl == nil {
		return fmt.Errorf("URL for queue %q is empty", qName)
	}

	pqi := sqs.PurgeQueueInput{QueueUrl: qu.QueueUrl}
	if _, err = r.PurgeQueue(ctx, &pqi); err != nil {
		return fmt.Errorf("failed purging queue %q: %w", qName, err)
	}
	return nil
}

// RedriveItems redrives the selected messages back to their respective source
// queue, removing them from the DLQ once they have been sent.
//
// Since a queue can act as a DLQ for more than one source queue, the message
// sends are grouped by source queue. Processing stops at the first group which
// fails.
func (r *redriver) RedriveItems(ctx context.Context, dlqName string, messages []Message) error {
	sqURLs := make(map[string][]Message, 1) // In most cases, only a single queue will be used.
	for _, message := range messages {
		if message.Attributes == nil {
			return fmt.Errorf("message %q has no system attributes: cannot resolve its source queue",
				message.MessageId)
		}
		sARN := message.Attributes.DeadLetterQueueSourceARN
		sURL, err := URLFromARNString(sARN)
		if err != nil {
			return fmt.Errorf("failed resolving source ARN %q to URL: %w", sARN, err)
		}
		sqURLs[sURL] = append(sqURLs[sURL], message)
	}

	for qURL, group := range sqURLs {
		if err := r.redriveQueueMessages(ctx, dlqName, qURL, group); err != nil {
			return err
		}
	}
	return nil
}

// redriveQueueMessages handles message redriving for messages in a single
// source queue: it hides them on the DLQ, sends them to qURL, then deletes
// them from the DLQ.
func (r *redriver) redriveQueueMessages(ctx context.Context, dlqName string, qURL string, messages []Message) error {
	if len(messages) == 0 {
		return nil
	}
	if len(messages) > BatchMax {
		return ErrBatchTooBig
	}

	qui := &sqs.GetQueueUrlInput{QueueName: &dlqName}
	dlqURL, err := r.GetQueueUrl(ctx, qui)
	if err != nil {
		return fmt.Errorf("failed getting URL for queue %q: %w", dlqName, err)
	}
	if dlqURL == nil || dlqURL.QueueUrl == nil {
		return fmt.Errorf("URL for queue %q is empty", dlqName)
	}

	// Hide messages to prevent other consumers from seeing them and generating duplicates.
	fatal, nonfatal := r.hideQueueMessages(ctx, dlqName, *dlqURL.QueueUrl, messages)
	if fatal != nil {
		return fmt.Errorf("failed hiding messages during redrive towards queue %q: %w", qURL, fatal)
	}
	if nonfatal != nil {
		log.Printf("Redrive nonfatal error hiding messages on queue %q: %v", dlqName, nonfatal)
	}

	// Send the messages back to their source queue.
	if err := r.resendQueueMessages(ctx, qURL, messages); err != nil {
		return fmt.Errorf("failed sending messages back to queue %q: %w", qURL, err)
	}

	// Delete them from the DLQ.
	keys := make([]ItemsKeys, len(messages))
	for i, m := range messages {
		keys[i] = m.Keys()
	}
	if err := r.DeleteItems(ctx, dlqName, keys); err != nil {
		return fmt.Errorf("failed deleting messages already redriven from DLQ %s to queue %q: beware of duplicates: %w",
			dlqName, qURL, err)
	}
	return nil
}

// resendQueueMessages sends messages to the queue at qURL in a single batch,
// adding a "previous-message-id" message attribute holding their original
// message ID.
//
// Any failed entry is reported as an error, although the other entries of the
// batch have been sent by then.
func (r *redriver) resendQueueMessages(ctx context.Context, qURL string, messages []Message) error {
	smbre := make([]types.SendMessageBatchRequestEntry, len(messages))
	for i, m := range messages {
		// Messages without attributes may carry a nil map, which cannot be
		// assigned into.
		attrs := nonNilMap(m.MessageAttributes)
		attrs["previous-message-id"] = m.MessageId
		mav, err := MessageAttributeValuesFromJSONable(attrs)
		if err != nil {
			return fmt.Errorf("failed converting message attributes for message %s on queue %q: %w",
				m.MessageId, qURL, err)
		}
		smbre[i] = types.SendMessageBatchRequestEntry{
			Id: aws.String(strconv.Itoa(i)),
			// Copy the body instead of pointing into the loop variable, which
			// is shared by all iterations before Go 1.22.
			MessageBody:       aws.String(m.Body),
			MessageAttributes: mav,
		}
	}

	smbi := sqs.SendMessageBatchInput{
		Entries:  smbre,
		QueueUrl: &qURL,
	}
	smbo, err := r.SendMessageBatch(ctx, &smbi)
	if err != nil {
		return fmt.Errorf("failed sending messages to queue %q: %w", qURL, err)
	}
	if len(smbo.Failed) == 0 {
		return nil
	}

	return fmt.Errorf("partial redrive: failed re-sending %d/%d messages, %v",
		len(smbo.Failed), len(smbi.Entries), batchErrors(smbo.Failed))
}

// hideQueueMessages applies the r.VTO visibility timeout to messages on the
// queue at qURL; dlqName is only used in error messages.
//
// The fatal error is set when the request fails or when no message at all
// could be hidden. The nonfatal error is set when only some of the messages
// could not be hidden.
func (r *redriver) hideQueueMessages(ctx context.Context, dlqName string, qURL string, messages []Message) (fatal, nonfatal error) {
	cmvbre := make([]types.ChangeMessageVisibilityBatchRequestEntry, len(messages))
	for i, m := range messages {
		cmvbre[i] = types.ChangeMessageVisibilityBatchRequestEntry{
			Id:                aws.String(strconv.Itoa(i)),
			ReceiptHandle:     aws.String(m.ReceiptHandle),
			VisibilityTimeout: r.VTO,
		}
	}
	cmvbi := sqs.ChangeMessageVisibilityBatchInput{
		Entries:  cmvbre,
		QueueUrl: aws.String(qURL),
	}
	cmvbo, err := r.ChangeMessageVisibilityBatch(ctx, &cmvbi)
	if err != nil {
		return fmt.Errorf("failed hiding request on DLQ %q: %w", dlqName, err), nil
	}

	switch len(cmvbo.Failed) {
	case len(cmvbi.Entries): // No message made it: abort.
		return fmt.Errorf("failed hiding all %d messages on DLQ %q", len(cmvbi.Entries), dlqName), nil
	case 0:
		return nil, nil // All well
	default:
		// Some messages made it, crossing fingers.
		return nil, fmt.Errorf("failed hiding %d/%d messages, %v",
			len(cmvbo.Failed), len(cmvbi.Entries), batchErrors(cmvbo.Failed))
	}
}

// RedriverService is the container service factory building a redriver from
// the SQS client service and the writer, visibility timeout and wait
// parameters found in dic.
//
// It returns an error when one of them does not have the expected type.
func RedriverService(dic *izidic.Container) (any, error) {
	cli, ok := dic.MustService(services.SvcClient).(*sqs.Client)
	if !ok {
		return nil, fmt.Errorf("service %v is not a *sqs.Client", services.SvcClient)
	}
	w, ok := dic.MustParam(services.PWriter).(io.Writer)
	if !ok {
		return nil, fmt.Errorf("parameter %v is not an io.Writer", services.PWriter)
	}
	vto, ok := dic.MustParam(services.PVTO).(time.Duration)
	if !ok {
		return nil, fmt.Errorf("parameter %v is not a time.Duration", services.PVTO)
	}
	wait, ok := dic.MustParam(services.PWait).(int)
	if !ok {
		return nil, fmt.Errorf("parameter %v is not an int", services.PWait)
	}

	return &redriver{
		Client: cli,
		VTO:    int32(vto.Seconds()),
		Wait:   int32(wait),
		Writer: w,
	}, nil
}