diff --git a/localstack-core/localstack/testing/pytest/fixtures.py b/localstack-core/localstack/testing/pytest/fixtures.py index 87823967336e7..3e6318a63f06a 100644 --- a/localstack-core/localstack/testing/pytest/fixtures.py +++ b/localstack-core/localstack/testing/pytest/fixtures.py @@ -6,7 +6,7 @@ import re import textwrap import time -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple import botocore.auth import botocore.config @@ -62,6 +62,11 @@ WAITER_STACK_DELETE_COMPLETE = "stack_delete_complete" +if TYPE_CHECKING: + from mypy_boto3_sqs import SQSClient + from mypy_boto3_sqs.type_defs import MessageTypeDef + + @pytest.fixture(scope="class") def aws_http_client_factory(aws_session): """ @@ -365,6 +370,65 @@ def factory(queue_url: str, expected_messages: int, max_iterations: int = 3): return factory +@pytest.fixture +def sqs_collect_messages(aws_client): + """Collects SQS messages from a given queue_url and deletes them by default. + Example usage: + messages = sqs_collect_messages( + my_queue_url, + expected=2, + timeout=10, + attribute_names=["All"], + message_attribute_names=["All"], + ) + """ + + def factory( + queue_url: str, + expected: int, + timeout: int, + delete: bool = True, + attribute_names: list[str] = None, + message_attribute_names: list[str] = None, + max_number_of_messages: int = 1, + wait_time_seconds: int = 5, + sqs_client: "SQSClient | None" = None, + ) -> list["MessageTypeDef"]: + sqs_client = sqs_client or aws_client.sqs + collected = [] + + def _receive(): + response = sqs_client.receive_message( + QueueUrl=queue_url, + # Maximum is 20 seconds. Performs long polling. + WaitTimeSeconds=wait_time_seconds, + # Maximum 10 messages + MaxNumberOfMessages=max_number_of_messages, + AttributeNames=attribute_names or [], + MessageAttributeNames=message_attribute_names or [], + ) + + if messages := response.get("Messages"): + collected.extend(messages) + + if delete: + for m in messages: + sqs_client.delete_message( + QueueUrl=queue_url, ReceiptHandle=m["ReceiptHandle"] + ) + + return len(collected) >= expected + + if not poll_condition(_receive, timeout=timeout): + raise TimeoutError( + f"gave up waiting for messages (expected={expected}, actual={len(collected)}" + ) + + return collected + + yield factory + + @pytest.fixture def sqs_queue(sqs_create_queue): return sqs_create_queue() diff --git a/tests/aws/services/sqs/test_sqs.py b/tests/aws/services/sqs/test_sqs.py index d2adbf8af2d0a..2903612e58e73 100644 --- a/tests/aws/services/sqs/test_sqs.py +++ b/tests/aws/services/sqs/test_sqs.py @@ -33,8 +33,6 @@ from tests.aws.services.lambda_.functions import lambda_integration from tests.aws.services.lambda_.test_lambda import TEST_LAMBDA_PYTHON -from .utils import sqs_collect_messages - if TYPE_CHECKING: from mypy_boto3_sqs import SQSClient @@ -2936,6 +2934,7 @@ def test_dead_letter_queue_message_attributes( sqs_create_queue, sqs_get_queue_arn, snapshot, + sqs_collect_messages, ): sqs = aws_client.sqs @@ -2990,7 +2989,6 @@ def test_dead_letter_queue_message_attributes( snapshot.match("rec-pre-dlq", messages) messages = sqs_collect_messages( - sqs, dl_queue_url, expected=2, timeout=10, diff --git a/tests/aws/services/sqs/test_sqs_move_task.py b/tests/aws/services/sqs/test_sqs_move_task.py index f4cc2085a5ffe..5cc9d5841dbde 100644 --- a/tests/aws/services/sqs/test_sqs_move_task.py +++ b/tests/aws/services/sqs/test_sqs_move_task.py @@ -10,7 +10,7 @@ from localstack.utils.aws import arns from localstack.utils.sync import retry -from .utils import sqs_collect_messages, sqs_wait_queue_size +from .utils import sqs_wait_queue_size QueueUrl = str @@ -125,6 +125,7 @@ def test_basic_move_task_workflow( sqs_create_queue, sqs_create_dlq_pipe, sqs_get_queue_arn, + sqs_collect_messages, aws_client, snapshot, ): @@ -161,7 +162,7 @@ def test_basic_move_task_workflow( assert decoded_source_arn == source_arn # check that messages arrived in destination queue correctly - messages = sqs_collect_messages(sqs, destination_queue, expected=2, timeout=10) + messages = sqs_collect_messages(destination_queue, expected=2, timeout=10) assert {message["Body"] for message in messages} == {"message-1", "message-2"} # check move task completion (in AWS, approximate number of messages may take a while to update) @@ -184,6 +185,7 @@ def test_move_task_workflow_with_default_destination( sqs_create_queue, sqs_create_dlq_pipe, sqs_get_queue_arn, + sqs_collect_messages, aws_client, snapshot, ): @@ -221,7 +223,7 @@ def test_move_task_workflow_with_default_destination( assert decoded_source_arn == source_arn # check that messages arrived in destination queue correctly - messages = sqs_collect_messages(sqs, queue_url, expected=2, timeout=10) + messages = sqs_collect_messages(queue_url, expected=2, timeout=10) assert {message["Body"] for message in messages} == {"message-1", "message-2"} # check move task completion (in AWS, approximate number of messages may take a while to update) @@ -244,6 +246,7 @@ def test_move_task_workflow_with_multiple_sources_as_default_destination( sqs_create_queue, sqs_create_dlq_pipe, sqs_get_queue_arn, + sqs_collect_messages, aws_client, snapshot, ): @@ -295,10 +298,10 @@ def test_move_task_workflow_with_multiple_sources_as_default_destination( snapshot.match("start-message-move-task-response", response) # check that messages arrived in destination queue correctly - messages = sqs_collect_messages(sqs, queue1_url, expected=2, timeout=10) + messages = sqs_collect_messages(queue1_url, expected=2, timeout=10) assert {message["Body"] for message in messages} == {"message-1-1", "message-1-2"} - messages = sqs_collect_messages(sqs, queue2_url, expected=2, timeout=10) + messages = sqs_collect_messages(queue2_url, expected=2, timeout=10) assert {message["Body"] for message in messages} == {"message-2-1", "message-2-2"} # check move task completion (in AWS, approximate number of messages may take a while to update) @@ -321,6 +324,7 @@ def test_move_task_with_throughput_limit( sqs_create_queue, sqs_create_dlq_pipe, sqs_get_queue_arn, + sqs_collect_messages, aws_client, snapshot, ): @@ -353,7 +357,7 @@ def test_move_task_with_throughput_limit( ) snapshot.match("start-message-move-task-response", response) started = time.time() - messages = sqs_collect_messages(sqs, destination_queue, n, 60) + messages = sqs_collect_messages(destination_queue, n, 60) assert {message["Body"] for message in messages} == { "message-0", "message-1", @@ -378,6 +382,7 @@ def test_move_task_cancel( sqs_create_queue, sqs_create_dlq_pipe, sqs_get_queue_arn, + sqs_collect_messages, aws_client, snapshot, ): @@ -411,7 +416,7 @@ def test_move_task_cancel( task_handle = response["TaskHandle"] # wait for two messages to arrive, then cancel the task - messages = sqs_collect_messages(sqs, destination_queue, 2, 60) + messages = sqs_collect_messages(destination_queue, 2, 60) assert len(messages) == 2 response = sqs.list_message_move_tasks(SourceArn=source_arn) diff --git a/tests/aws/services/sqs/utils.py b/tests/aws/services/sqs/utils.py index 6cde5b4881897..6887d44f61d9a 100644 --- a/tests/aws/services/sqs/utils.py +++ b/tests/aws/services/sqs/utils.py @@ -4,45 +4,6 @@ if TYPE_CHECKING: from mypy_boto3_sqs import SQSClient - from mypy_boto3_sqs.type_defs import MessageTypeDef - - -def sqs_collect_messages( - sqs_client: "SQSClient", - queue_url: str, - expected: int, - timeout: int, - delete: bool = True, - attribute_names: list[str] = None, - message_attribute_names: list[str] = None, -) -> list["MessageTypeDef"]: - collected = [] - - def _receive(): - response = sqs_client.receive_message( - QueueUrl=queue_url, - # try not to wait too long, but also not poll too often - WaitTimeSeconds=min(max(1, timeout), 5), - MaxNumberOfMessages=1, - AttributeNames=attribute_names or [], - MessageAttributeNames=message_attribute_names or [], - ) - - if messages := response.get("Messages"): - collected.extend(messages) - - if delete: - for m in messages: - sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=m["ReceiptHandle"]) - - return len(collected) >= expected - - if not poll_condition(_receive, timeout=timeout): - raise TimeoutError( - f"gave up waiting for messages (expected={expected}, actual={len(collected)}" - ) - - return collected def get_approx_number_of_messages(