fix message retry in lambda SQS event source mapping by thrau · Pull Request #6603 · localstack/localstack
Merged · merged 4 commits on Aug 8, 2022
Changes from all commits
@@ -49,8 +49,6 @@ def _listener_loop(self, *args):
                 self.SQS_LISTENER_THREAD.pop("_thread_")
                 return
 
-            unprocessed_messages = {}
-
             for source in sources:
                 queue_arn = source["EventSourceArn"]
                 region_name = queue_arn.split(":")[3]
@@ -59,21 +57,17 @@ def _listener_loop(self, *args):
 
                 try:
                     queue_url = aws_stack.sqs_queue_url_for_arn(queue_arn)
-                    messages = unprocessed_messages.pop(queue_arn, None)
-                    if not messages:
-                        result = sqs_client.receive_message(
-                            QueueUrl=queue_url,
-                            AttributeNames=["All"],
-                            MessageAttributeNames=["All"],
-                            MaxNumberOfMessages=batch_size,
-                        )
-                        messages = result.get("Messages")
-                        if not messages:
-                            continue
-
-                    res = self._process_messages_for_event_source(source, messages)
-                    if not res:
-                        unprocessed_messages[queue_arn] = messages
+                    result = sqs_client.receive_message(
+                        QueueUrl=queue_url,
+                        AttributeNames=["All"],
+                        MessageAttributeNames=["All"],
+                        MaxNumberOfMessages=batch_size,
+                    )
+                    messages = result.get("Messages")
+                    if not messages:
+                        continue
+
+                    self._process_messages_for_event_source(source, messages)
 
                 except Exception as e:
                     if "NonExistentQueue" not in str(e):

Comment on lines +60 to +65

Member: Not 100% sure but I think this needs to have a value set for WaitTimeSeconds since per default the SQS poller uses long polling (for non-FIFO queues at least) according to https://docs.aws.amazon.com/lambda/latest/dg/with-sqs.html#events-sqs-scaling

Member: In localstack that might not be particularly relevant though because AFAIK we don't differentiate between short and long polling anyway.

Member (Author): totally agree. unfortunately we can't block here because of the way the loop is structured. see comment here: #5042 (comment)

that would require a thread per event-source-mapping, which again requires a different interface for the event source listener (adding/removing mappings). i'd love to do that as part of the lambda rework.
@@ -100,12 +94,12 @@ def _process_messages_for_event_source(self, source, messages):
         )
         return res
 
-    def _send_event_to_lambda(self, queue_arn, queue_url, lambda_arn, messages, region):
-        def delete_messages(result, func_arn, event, error=None, dlq_sent=None, **kwargs):
-            if error and not dlq_sent:
-                # Skip deleting messages from the queue in case of processing errors AND if
-                # the message has not yet been sent to a dead letter queue (DLQ).
-                # We'll pick them up and retry next time they become available on the queue.
+    def _send_event_to_lambda(self, queue_arn, queue_url, lambda_arn, messages, region) -> bool:
+        def delete_messages(result, func_arn, event, error=None, **kwargs):
+            if error:
+                # Skip deleting messages from the queue in case of processing errors. We'll pick them up and retry
+                # next time they become visible in the queue. Redrive policies will be handled automatically by SQS
+                # on the next polling attempt.
                 return
 
         region_name = queue_arn.split(":")[3]
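The rewritten comment delegates DLQ handling to the queue's own redrive policy: a message that is never deleted becomes visible again, and SQS moves it to the DLQ once it has been received more than maxReceiveCount times. For illustration, a sketch of a source queue wired up that way (queue names and the LocalStack edge endpoint are illustrative):

```python
import json

import boto3

# Illustrative setup: after maxReceiveCount failed receives (the lambda kept
# failing, so the messages were never deleted), SQS itself redrives the
# message to the DLQ on a subsequent receive.
sqs = boto3.client("sqs", endpoint_url="http://localhost:4566")
dlq_url = sqs.create_queue(QueueName="event-source-dlq")["QueueUrl"]
dlq_arn = sqs.get_queue_attributes(QueueUrl=dlq_url, AttributeNames=["QueueArn"])[
    "Attributes"
]["QueueArn"]
sqs.create_queue(
    QueueName="event-source-queue",
    Attributes={
        "RedrivePolicy": json.dumps({"deadLetterTargetArn": dlq_arn, "maxReceiveCount": "2"})
    },
)
```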
@@ -124,25 +118,25 @@ def delete_messages(result, func_arn, event, error=None, dlq_sent=None, **kwargs
         records = []
         for msg in messages:
             message_attrs = message_attributes_to_lower(msg.get("MessageAttributes"))
-            records.append(
-                {
-                    "body": msg.get("Body", "MessageBody"),
-                    "receiptHandle": msg.get("ReceiptHandle"),
-                    "md5OfBody": msg.get("MD5OfBody") or msg.get("MD5OfMessageBody"),
-                    "eventSourceARN": queue_arn,
-                    "eventSource": lambda_executors.EVENT_SOURCE_SQS,
-                    "awsRegion": region,
-                    "messageId": msg["MessageId"],
-                    "attributes": msg.get("Attributes", {}),
-                    "messageAttributes": message_attrs,
-                    "md5OfMessageAttributes": msg.get("MD5OfMessageAttributes"),
-                    "sqs": True,
-                }
-            )
+            record = {
+                "body": msg.get("Body", "MessageBody"),
+                "receiptHandle": msg.get("ReceiptHandle"),
+                "md5OfBody": msg.get("MD5OfBody") or msg.get("MD5OfMessageBody"),
+                "eventSourceARN": queue_arn,
+                "eventSource": lambda_executors.EVENT_SOURCE_SQS,
+                "awsRegion": region,
+                "messageId": msg["MessageId"],
+                "attributes": msg.get("Attributes", {}),
+                "messageAttributes": message_attrs,
+            }
+
+            if md5OfMessageAttributes := msg.get("MD5OfMessageAttributes"):
+                record["md5OfMessageAttributes"] = md5OfMessageAttributes
+
+            records.append(record)
 
         event = {"Records": records}
 
-        # TODO implement retries, based on "RedrivePolicy.maxReceiveCount" in the queue settings
         res = run_lambda(
             func_arn=lambda_arn,
             event=event,
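For reference, the records assembled above yield an SQS-shaped lambda event roughly like the following (all values illustrative, assuming EVENT_SOURCE_SQS is the standard "aws:sqs" event source string; note that md5OfMessageAttributes is now only included when the message actually carries message attributes):

```python
# Illustrative invocation payload; keys mirror the record dict built above.
event = {
    "Records": [
        {
            "body": '{"destination": "...", "fail_attempts": 2}',
            "receiptHandle": "AQEB...",
            "md5OfBody": "9a4f4bf0e6f7779f86e4bf1cbbf24b4a",
            "eventSourceARN": "arn:aws:sqs:us-east-1:000000000000:event-source-queue",
            "eventSource": "aws:sqs",
            "awsRegion": "us-east-1",
            "messageId": "059f36b4-87a3-44ab-83d2-661975830a7d",
            "attributes": {"ApproximateReceiveCount": "1"},
            "messageAttributes": {},
        }
    ]
}
```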
24 changes: 6 additions & 18 deletions localstack/services/awslambda/lambda_executors.py

@@ -25,7 +25,6 @@
     LAMBDA_RUNTIME_PROVIDED,
     get_container_network_for_lambda,
     get_main_endpoint_from_container,
-    get_record_from_event,
     is_java_lambda,
     is_nodejs_runtime,
     rm_docker_container,
@@ -34,10 +33,7 @@
 from localstack.services.install import GO_LAMBDA_RUNTIME, INSTALL_PATH_LOCALSTACK_FAT_JAR
 from localstack.utils.aws import aws_stack
 from localstack.utils.aws.aws_models import LambdaFunction
-from localstack.utils.aws.dead_letter_queue import (
-    lambda_error_to_dead_letter_queue,
-    sqs_error_to_dead_letter_queue,
-)
+from localstack.utils.aws.dead_letter_queue import lambda_error_to_dead_letter_queue
 from localstack.utils.cloudwatch.cloudwatch_util import cloudwatched
 from localstack.utils.collections import select_attributes
 from localstack.utils.common import (
@@ -354,14 +350,7 @@ def handle_error(
     lambda_function: LambdaFunction, event: Dict, error: Exception, asynchronous: bool = False
 ):
     if asynchronous:
-        if get_record_from_event(event, "eventSource") == EVENT_SOURCE_SQS:
-            sqs_queue_arn = get_record_from_event(event, "eventSourceARN")
-            if sqs_queue_arn:
-                # event source is SQS, send event back to dead letter queue
-                return sqs_error_to_dead_letter_queue(sqs_queue_arn, event, error)
-        else:
-            # event source is not SQS, send back to lambda dead letter queue
-            lambda_error_to_dead_letter_queue(lambda_function, event, error)
+        lambda_error_to_dead_letter_queue(lambda_function, event, error)
 
 
 class LambdaAsyncLocks:
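With the SQS special case removed, failed asynchronous invocations are routed only through the function's own dead letter config. For illustration, a sketch of how such a target is attached to a function (function name and ARN are illustrative):

```python
import boto3

# Illustrative: lambda_error_to_dead_letter_queue reads the TargetArn from the
# function's DeadLetterConfig, which users set like this.
awslambda = boto3.client("lambda", endpoint_url="http://localhost:4566")
awslambda.update_function_configuration(
    FunctionName="my-function",
    DeadLetterConfig={"TargetArn": "arn:aws:sqs:us-east-1:000000000000:my-function-dlq"},
)
```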
Expand Down Expand Up @@ -425,7 +414,6 @@ def _run(func_arn=None):
# start the execution
raised_error = None
result = None
dlq_sent = None
invocation_type = "Event" if asynchronous else "RequestResponse"
inv_context = InvocationContext(
lambda_function,
Expand All @@ -438,13 +426,13 @@ def _run(func_arn=None):
result = self._execute(lambda_function, inv_context)
except Exception as e:
raised_error = e
dlq_sent = handle_error(lambda_function, event, e, asynchronous)
handle_error(lambda_function, event, e, asynchronous)
raise e
finally:
self.function_invoke_times[func_arn] = invocation_time
callback and callback(
result, func_arn, event, error=raised_error, dlq_sent=dlq_sent
)
if callback:
callback(result, func_arn, event, error=raised_error)

lambda_result_to_destination(
lambda_function, event, result, asynchronous, raised_error
)
Expand Down
8 changes: 8 additions & 0 deletions localstack/testing/snapshots/transformer_utility.py

@@ -50,6 +50,14 @@ def key_value(
             replace_reference=reference_replacement,
         )
 
+    @staticmethod
+    def resource_name(replacement_name: str = "resource"):
+        """Creates a new KeyValueBasedTransformer for the resource name.
+
+        :return: KeyValueBasedTransformer
+        """
+        return KeyValueBasedTransformer(_resource_name_transformer, replacement_name)
+
     @staticmethod
     def jsonpath(jsonpath: str, value_replacement: str, reference_replacement: bool = True):
         """Creates a new JsonpathTransformer. If the jsonpath matches, the value will be replaced.
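A snapshot test would presumably register the new transformer roughly like this (a sketch; it assumes the usual `snapshot` fixture exposes this utility as `snapshot.transform`):

```python
def test_event_source_mapping(snapshot):
    # replace generated resource names with a stable "<resource>" placeholder
    # so snapshots don't churn on randomized names (fixture wiring assumed)
    snapshot.add_transformer(snapshot.transform.resource_name())
```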
32 changes: 3 additions & 29 deletions localstack/utils/aws/dead_letter_queue.py

@@ -1,7 +1,6 @@
 import json
 import logging
 import uuid
-from json import JSONDecodeError
 from typing import Dict, List, Union
 
 from localstack.utils.aws import aws_stack
@@ -11,47 +10,22 @@
 LOG = logging.getLogger(__name__)
 
 
-def sqs_error_to_dead_letter_queue(queue_arn: str, event: Dict, error):
-    client = aws_stack.connect_to_service("sqs")
-    queue_url = aws_stack.get_sqs_queue_url(queue_arn)
-    attrs = client.get_queue_attributes(QueueUrl=queue_url, AttributeNames=["RedrivePolicy"])
-    attrs = attrs.get("Attributes", {})
-    try:
-        policy = json.loads(attrs.get("RedrivePolicy") or "{}")
-    except JSONDecodeError:
-        LOG.warning(
-            "Parsing RedrivePolicy {} failed, Queue: {}".format(
-                attrs.get("RedrivePolicy"), queue_arn
-            )
-        )
-        return
-
-    target_arn = policy.get("deadLetterTargetArn")
-    if not target_arn:
-        return
-    return _send_to_dead_letter_queue("SQS", queue_arn, target_arn, event, error)
-
-
 def sns_error_to_dead_letter_queue(sns_subscriber: dict, event: str, error):
     # event should be of type str if coming from SNS, as it represents the message body being passed down
     policy = json.loads(sns_subscriber.get("RedrivePolicy") or "{}")
     target_arn = policy.get("deadLetterTargetArn")
     if not target_arn:
         return
-    return _send_to_dead_letter_queue(
-        "SNS", sns_subscriber["SubscriptionArn"], target_arn, event, error
-    )
+    return _send_to_dead_letter_queue(sns_subscriber["SubscriptionArn"], target_arn, event, error)
 
 
 def lambda_error_to_dead_letter_queue(func_details: LambdaFunction, event: Dict, error):
     dlq_arn = (func_details.dead_letter_config or {}).get("TargetArn")
     source_arn = func_details.id
-    return _send_to_dead_letter_queue("Lambda", source_arn, dlq_arn, event, error)
+    return _send_to_dead_letter_queue(source_arn, dlq_arn, event, error)
 
 
-def _send_to_dead_letter_queue(
-    source_type: str, source_arn: str, dlq_arn: str, event: Union[Dict, str], error
-):
+def _send_to_dead_letter_queue(source_arn: str, dlq_arn: str, event: Union[Dict, str], error):
     if not dlq_arn:
         return
     LOG.info("Sending failed execution %s to dead letter queue %s", source_arn, dlq_arn)
2 changes: 1 addition & 1 deletion localstack/utils/aws/message_forwarding.py

@@ -34,7 +34,7 @@ def lambda_result_to_destination(
     event: Dict,
     result: InvocationResult,
     is_async: bool,
-    error: InvocationException,
+    error: Optional[InvocationException],
 ):
     if not func_details.destination_enabled():
         return
52 changes: 52 additions & 0 deletions tests/integration/awslambda/functions/lambda_sqs_integration.py

@@ -0,0 +1,52 @@
"""This lambda is used for lambda/sqs integration tests. Since SQS event source mappings don't allow
DestinationConfigurations that send lambda results to other sources (like SQS queues) that could be used to verify
invocations, this lambda does so manually. You can pass in an event that looks like this::

    {
        "destination": "<queue_url>",
        "fail_attempts": 2
    }

This will cause the lambda to fail twice (comparing the "ApproximateReceiveCount" of the SQS event triggering
the lambda), and send either an error or success result to the SQS queue passed in the destination key.
"""
import json
import os

import boto3


def handler(event, context):
    # this lambda expects inputs from an SQS event source mapping
    if len(event.get("Records", [])) != 1:
        raise ValueError("the payload must consist of exactly one record")

    # it expects exactly one record where the message body is '{"destination": "<queue_url>"}' that mimics a
    # DestinationConfig (which is not possible with SQS event source mappings).
    record = event["Records"][0]
    message = json.loads(record["body"])

    if not message.get("destination"):
        raise ValueError("no destination for the event given")

    error = None
    try:
        if message["fail_attempts"] >= int(record["attributes"]["ApproximateReceiveCount"]):
            raise ValueError("failed attempt")
    except Exception as e:
        error = e
        raise
    finally:
        # we then send a message to the destination queue
        result = {"error": None if not error else str(error), "event": event}
        sqs = create_external_boto_client("sqs")
        sqs.send_message(QueueUrl=message.get("destination"), MessageBody=json.dumps(result))


def create_external_boto_client(service):
    endpoint_url = None
    if os.environ.get("LOCALSTACK_HOSTNAME"):
        endpoint_url = (
            f"http://{os.environ['LOCALSTACK_HOSTNAME']}:{os.environ.get('EDGE_PORT', 4566)}"
        )
    return boto3.client(service, endpoint_url=endpoint_url)
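A test exercising this function might drive it roughly as follows (a sketch; `queue_url` is the queue behind the event source mapping, `destination_url` collects the results, and an ordered arrival of results is assumed for simplicity):

```python
import json

# Sketch: with fail_attempts=2, ApproximateReceiveCount 1 and 2 raise and the
# third delivery succeeds, so the destination queue should collect two error
# results followed by one success result.
payload = {"destination": destination_url, "fail_attempts": 2}
sqs_client.send_message(QueueUrl=queue_url, MessageBody=json.dumps(payload))

for expect_error in (True, True, False):
    response = sqs_client.receive_message(QueueUrl=destination_url, WaitTimeSeconds=15)
    result = json.loads(response["Messages"][0]["Body"])
    assert bool(result["error"]) is expect_error
```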
2 changes: 2 additions & 0 deletions tests/integration/awslambda/test_lambda_integration.py

@@ -79,6 +79,8 @@
 
 
 class TestSQSEventSourceMapping:
+    # FIXME: refactor and move to test_lambda_sqs_integration
+
     @pytest.mark.skip_snapshot_verify
     def test_event_source_mapping_default_batch_size(
         self,