package services

import (
	"context"
	"crypto/md5"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"strconv"
	"time"

	"github.com/aws/aws-sdk-go-v2/service/sqs"
	"github.com/aws/aws-sdk-go-v2/service/sqs/types"
	"github.com/davecgh/go-spew/spew"
	"github.com/fgm/izidic"
	"github.com/google/uuid"
)

// Event is a received SQS message after validation by validateMessage.
//
// The JSON tags on the top-level fields use the same names as the
// corresponding fields of the SQS message (MessageId, MD5OfBody,
// MD5OfMessageAttributes).
type Event struct {
	MessageID uuid.UUID `json:"MessageId"`
	BodySum   string    `json:"MD5OfBody"`
	AttrSum   string    `json:"MD5OfMessageAttributes"`

	EventAttributes
	MessageAttributes map[string]types.MessageAttributeValue
	Body              []byte
}

// IsRetryable reports whether the event carries a "retry" message attribute
// whose data type is "String" and whose string value is exactly "1".
// Any missing piece (attribute, data type or value) means "not retryable".
func (e Event) IsRetryable() bool {
	attr, found := e.MessageAttributes["retry"]
	switch {
	case !found, attr.DataType == nil, attr.StringValue == nil:
		return false
	case *attr.DataType != "String":
		return false
	default:
		return *attr.StringValue == "1"
	}
}

// EventAttributes holds the SQS system attributes that validateMessage
// extracts from a received message: the sender, the send time, the receive
// count and the first receive time.
//
// NOTE(review): the JSON tags name the SQS attributes, but validateMessage
// parses the timestamps by hand from millisecond strings; encoding/json would
// not decode such strings into time.Time directly — confirm the tags are only
// used for output.
type EventAttributes struct {
	SenderID                         string    `json:"SenderId"`
	SentTime                         time.Time `json:"SentTimestamp"`
	ApproximateReceiveCount          int       `json:"ApproximateReceiveCount"`
	ApproximateFirstReceiveTimestamp time.Time `json:"ApproximateFirstReceiveTimestamp"`
}

// Handler processes one validated SQS message.
//
// It is given the message ID, the time the message was sent, the raw body and
// the message attributes, along with the JSON encoder the consumer writes its
// output through. A returned error is only logged by the consumer.
type Handler func(
	ctx context.Context,
	enc *json.Encoder,
	msgID uuid.UUID,
	sent time.Time,
	input []byte,
	meta map[string]types.MessageAttributeValue,
) error

// message is a local copy of the SQS types.Message type, declared so that it
// can implement fmt.Stringer and be logged with %s.
type message types.Message

// String returns the SQS message ID, or "0" when the message has none.
func (m message) String() string {
	if id := m.MessageId; id != nil {
		return *id
	}
	return "0"
}

// ConsumerService is an izidic service factory.
//
// It returns a func(ctx, qURL) error that consumes messages from the queue at
// qURL. The function uses the "sqs" client service and the Handler held in the
// PHandler parameter, and writes through a JSON encoder built once over the
// PWriter parameter. A dependency of the wrong type makes the unchecked type
// assertions below panic.
func ConsumerService(dic *izidic.Container) (any, error) {
	var (
		client  = dic.MustService("sqs").(*sqs.Client)
		writer  = dic.MustParam(PWriter).(io.Writer)
		handler = dic.MustParam(PHandler).(Handler)
	)
	encoder := json.NewEncoder(writer)
	consume := func(ctx context.Context, qURL string) error {
		return consumeMessage(ctx, writer, encoder, client, qURL, handler)
	}
	return consume, nil
}

// ReceiverService is an izidic service factory.
//
// It returns a func(ctx, qURL) that performs one receive on the queue at qURL
// using the "sqs" client service, and dumps the result to the PWriter
// parameter. A dependency of the wrong type makes the unchecked type
// assertions below panic.
func ReceiverService(dic *izidic.Container) (any, error) {
	var (
		client = dic.MustService("sqs").(*sqs.Client)
		writer = dic.MustParam(PWriter).(io.Writer)
	)
	receive := func(ctx context.Context, qURL string) {
		receiveMessage(ctx, writer, client, qURL)
	}
	return receive, nil
}

// consumeMessage long-polls the queue at qURL for one message at a time and
// hands each valid message to hdl.
//
// Messages that fail validation, or whose handler returns an error, are logged
// and then deleted anyway. Valid messages for which Event.IsRetryable is true
// are left on the queue instead of being deleted. The io.Writer parameter is
// unused: handler output goes through enc.
//
// It loops until ReceiveMessage returns an error or returns more than one
// message, and returns a wrapped error in both cases. A cancelled ctx is
// expected to surface as a ReceiveMessage error.
//
// NOTE(review): with a 1-second VisibilityTimeout, a handler that runs longer
// than that may let the message become visible again before it is deleted.
func consumeMessage(ctx context.Context, _ io.Writer, enc *json.Encoder, client *sqs.Client, qURL string, hdl Handler) error {
	rmi := sqs.ReceiveMessageInput{
		QueueUrl:              &qURL,
		AttributeNames:        []types.QueueAttributeName{"All"},
		MessageAttributeNames: []string{"All"},
		MaxNumberOfMessages:   1, // Default, also used when set to 0
		VisibilityTimeout:     1,
		WaitTimeSeconds:       5,
	}

	for {
		recv, err := client.ReceiveMessage(ctx, &rmi)
		if err != nil {
			return fmt.Errorf("failed receiving from queue: %w, aborting", err)
		}
		if len(recv.Messages) == 0 {
			log.Printf("No message with %d seconds timeout\n", rmi.WaitTimeSeconds)
			continue
		}
		if len(recv.Messages) != 1 {
			return fmt.Errorf("unexpected number of messages: %d, expected 0 or 1, aborting", len(recv.Messages))
		}

		msg := message(recv.Messages[0])
		evt, err := validateMessage(msg)
		if err != nil {
			log.Printf("invalid message %s: %v, dropping it anyway", msg, err)
		} else if err := hdl(ctx, enc, evt.MessageID, evt.SentTime, evt.Body, evt.MessageAttributes); err != nil {
			log.Printf("message %s failed processing : %v, dropping it anyway\n", msg, err)
		} else {
			log.Printf("message %s processed successfully\n", msg)
		}

		// validateMessage returns a nil *Event on error. IsRetryable has a value
		// receiver, so calling it through that nil pointer would panic. Invalid
		// messages are dropped, so they are never retryable.
		if evt != nil && evt.IsRetryable() {
			log.Printf("message %s not deleted, for retry", msg)
			continue
		}

		dmi := sqs.DeleteMessageInput{
			QueueUrl:      &qURL,
			ReceiptHandle: msg.ReceiptHandle,
		}
		if _, err := client.DeleteMessage(ctx, &dmi); err != nil {
			// Processing may have failed too, so don't claim it succeeded.
			log.Printf("Error deleting message %s after processing: %v\n", msg, err)
			continue
		}
		log.Printf("message %s deleted after processing\n", msg)
	}
}

// receiveMessage makes a single ReceiveMessage call on the queue at qURL.
// The call requests all system and message attributes, and the returned
// messages are dumped to w with spew. Messages are not deleted. A receive
// error terminates the process through log.Fatalf.
func receiveMessage(ctx context.Context, w io.Writer, client *sqs.Client, qURL string) {
	out, err := client.ReceiveMessage(ctx, &sqs.ReceiveMessageInput{
		QueueUrl:              &qURL,
		AttributeNames:        []types.QueueAttributeName{"All"},
		MessageAttributeNames: []string{"All"},
		VisibilityTimeout:     1,
		WaitTimeSeconds:       5,
	})
	if err != nil {
		log.Fatalf("failed receiving from queue %s: %v", qURL, err)
	}
	spew.Fdump(w, out.Messages)
}

// validateMessage checks a received SQS message and converts it to an Event.
//
// It returns an error, and a nil *Event, in any of these cases:
//   - the MessageId is missing or is not a UUID;
//   - MD5OfBody or the body is missing;
//   - the SentTimestamp, ApproximateReceiveCount or
//     ApproximateFirstReceiveTimestamp attributes are missing or not decimal
//     integers;
//   - MD5OfBody is not a hex-encoded 16-byte digest, or does not match the MD5
//     of the body.
//
// MD5OfMessageAttributes is copied into the Event but not verified.
func validateMessage(msg message) (*Event, error) {
	var (
		err error
		evt = Event{EventAttributes: EventAttributes{SenderID: msg.Attributes["SenderId"]}}
	)

	// Top-level fields
	if msg.MessageId == nil {
		return nil, fmt.Errorf("error: MessageId is nil")
	}
	evt.MessageID, err = uuid.Parse(*msg.MessageId)
	if err != nil {
		return nil, fmt.Errorf("error parsing MessageId as a UUID: %w", err)
	}
	if msg.MD5OfBody == nil {
		return nil, fmt.Errorf("error: MD5OfBody is nil")
	}
	evt.BodySum = *msg.MD5OfBody
	if msg.MD5OfMessageAttributes != nil {
		evt.AttrSum = *msg.MD5OfMessageAttributes
	}

	// EventAttributes fields. Timestamps are milliseconds since the Unix epoch.
	// Parse them as int64, because the values overflow a 32-bit int. Convert
	// them with UnixMilli, so the sub-second part is scaled to nanoseconds
	// correctly.
	sentMsec, err := strconv.ParseInt(msg.Attributes["SentTimestamp"], 10, 64)
	if err != nil {
		return nil, fmt.Errorf("error parsing SentTimestamp as milliseconds: %w", err)
	}
	evt.EventAttributes.SentTime = time.UnixMilli(sentMsec)

	evt.EventAttributes.ApproximateReceiveCount, err = strconv.Atoi(msg.Attributes["ApproximateReceiveCount"])
	if err != nil {
		return nil, fmt.Errorf("error parsing ApproximateReceiveCount: %w", err)
	}
	firstMsec, err := strconv.ParseInt(msg.Attributes["ApproximateFirstReceiveTimestamp"], 10, 64)
	if err != nil {
		return nil, fmt.Errorf("error parsing ApproximateFirstReceiveTimestamp as milliseconds: %w", err)
	}
	evt.EventAttributes.ApproximateFirstReceiveTimestamp = time.UnixMilli(firstMsec)

	evt.MessageAttributes = msg.MessageAttributes

	// EventBody field
	if msg.Body == nil {
		return nil, fmt.Errorf("message body is nil")
	}
	body := []byte(*msg.Body)
	expected, err := hex.DecodeString(evt.BodySum)
	if err != nil {
		return nil, fmt.Errorf("error parsing body sum as a hex string: %w", err)
	}
	// Check the length first: converting a shorter slice to *[md5.Size]byte
	// panics, and a longer one would be silently truncated.
	if len(expected) != md5.Size {
		return nil, fmt.Errorf("error: body sum is %d bytes long, expected %d", len(expected), md5.Size)
	}
	if actual := md5.Sum(body); actual != *(*[md5.Size]byte)(expected) {
		return nil, fmt.Errorf("error: body MD5 sum %x does not match MD5OfBody %s", actual, evt.BodySum)
	}
	evt.Body = body
	return &evt, nil
}