8000 Convert SQS test util into re-usable fixture sqs_collect_messages by joe4dev · Pull Request #11757 · localstack/localstack · GitHub
[go: up one dir, main page]

Skip to content

Convert SQS test util into re-usable fixture sqs_collect_messages #11757

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 65 additions & 1 deletion localstack-core/localstack/testing/pytest/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import re
import textwrap
import time
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

import botocore.auth
import botocore.config
Expand Down Expand Up @@ -62,6 +62,11 @@
WAITER_STACK_DELETE_COMPLETE = "stack_delete_complete"


if TYPE_CHECKING:
from mypy_boto3_sqs import SQSClient
from mypy_boto3_sqs.type_defs import MessageTypeDef


@pytest.fixture(scope="class")
def aws_http_client_factory(aws_session):
"""
Expand Down Expand Up @@ -365,6 +370,65 @@ def factory(queue_url: str, expected_messages: int, max_iterations: int = 3):
return factory


@pytest.fixture
def sqs_collect_messages(aws_client):
"""Collects SQS messages from a given queue_url and deletes them by default.
Example usage:
messages = sqs_collect_messages(
my_queue_url,
expected=2,
timeout=10,
attribute_names=["All"],
message_attribute_names=["All"],
)
"""

def factory(
queue_url: str,
expected: int,
timeout: int,
delete: bool = True,
attribute_names: list[str] = None,
message_attribute_names: list[str] = None,
max_number_of_messages: int = 1,
wait_time_seconds: int = 5,
sqs_client: "SQSClient | None" = None,
) -> list["MessageTypeDef"]:
sqs_client = sqs_client or aws_client.sqs
collected = []

def _receive():
response = sqs_client.receive_message(
QueueUrl=queue_url,
# Maximum is 20 seconds. Performs long polling.
WaitTimeSeconds=wait_time_seconds,
# Maximum 10 messages
MaxNumberOfMessages=max_number_of_messages,
AttributeNames=attribute_names or [],
MessageAttributeNames=message_attribute_names or [],
)

if messages := response.get("Messages"):
collected.extend(messages)

if delete:
for m in messages:
sqs_client.delete_message(
QueueUrl=queue_url, ReceiptHandle=m["ReceiptHandle"]
)

return len(collected) >= expected

if not poll_condition(_receive, timeout=timeout):
raise TimeoutError(
f"gave up waiting for messages (expected={expected}, actual={len(collected)}"
)

return collected

yield factory


@pytest.fixture
def sqs_queue(sqs_create_queue):
return sqs_create_queue()
Expand Down
4 changes: 1 addition & 3 deletions tests/aws/services/sqs/test_sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@
from tests.aws.services.lambda_.functions import lambda_integration
from tests.aws.services.lambda_.test_lambda import TEST_LAMBDA_PYTHON

from .utils import sqs_collect_messages

if TYPE_CHECKING:
from mypy_boto3_sqs import SQSClient

Expand Down Expand Up @@ -2936,6 +2934,7 @@ def test_dead_letter_queue_message_attributes(
sqs_create_queue,
sqs_get_queue_arn,
snapshot,
sqs_collect_messages,
):
sqs = aws_client.sqs

Expand Down Expand Up @@ -2990,7 +2989,6 @@ def test_dead_letter_queue_message_attributes(
snapshot.match("rec-pre-dlq", messages)

messages = sqs_collect_messages(
sqs,
dl_queue_url,
expected=2,
timeout=10,
Expand Down
19 changes: 12 additions & 7 deletions tests/aws/services/sqs/test_sqs_move_task.py
8000
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from localstack.utils. 8000 aws import arns
from localstack.utils.sync import retry

from .utils import sqs_collect_messages, sqs_wait_queue_size
from .utils import sqs_wait_queue_size

QueueUrl = str

Expand Down Expand Up @@ -125,6 +125,7 @@ def test_basic_move_task_workflow(
sqs_create_queue,
sqs_create_dlq_pipe,
sqs_get_queue_arn,
sqs_collect_messages,
aws_client,
snapshot,
):
Expand Down Expand Up @@ -161,7 +162,7 @@ def test_basic_move_task_workflow(
assert decoded_source_arn == source_arn

# check that messages arrived in destination queue correctly
messages = sqs_collect_messages(sqs, destination_queue, expected=2, timeout=10)
messages = sqs_collect_messages(destination_queue, expected=2, timeout=10)
assert {message["Body"] for message in messages} == {"message-1", "message-2"}

# check move task completion (in AWS, approximate number of messages may take a while to update)
Expand All @@ -184,6 +185,7 @@ def test_move_task_workflow_with_default_destination(
sqs_create_queue,
sqs_create_dlq_pipe,
sqs_get_queue_arn,
sqs_collect_messages,
aws_client,
snapshot,
):
Expand Down Expand Up @@ -221,7 +223,7 @@ def test_move_task_workflow_with_default_destination(
assert decoded_source_arn == source_arn

# check that messages arrived in destination queue correctly
messages = sqs_collect_messages(sqs, queue_url, expected=2, timeout=10)
messages = sqs_collect_messages(queue_url, expected=2, timeout=10)
assert {message["Body"] for message in messages} == {"message-1", "message-2"}

# check move task completion (in AWS, approximate number of messages may take a while to update)
Expand All @@ -244,6 +246,7 @@ def test_move_task_workflow_with_multiple_sources_as_default_destination(
sqs_create_queue,
sqs_create_dlq_pipe,
sqs_get_queue_arn,
sqs_collect_messages,
aws_client,
snapshot,
):
Expand Down Expand Up @@ -295,10 +298,10 @@ def test_move_task_workflow_with_multiple_sources_as_default_destination(
snapshot.match("start-message-move-task-response", response)

# check that messages arrived in destination queue correctly
messages = sqs_collect_messages(sqs, queue1_url, expected=2, timeout=10)
messages = sqs_collect_messages(queue1_url, expected=2, timeout=10)
assert {message["Body"] for message in messages} == {"message-1-1", "message-1-2"}

messages = sqs_collect_messages(sqs, queue2_url, expected=2, timeout=10)
messages = sqs_collect_messages(queue2_url, expected=2, timeout=10)
assert {message["Body"] for message in messages} == {"message-2-1", "message-2-2"}

# check move task completion (in AWS, approximate number of messages may take a while to update)
Expand All @@ -321,6 +324,7 @@ def test_move_task_with_throughput_limit(
sqs_create_queue,
sqs_create_dlq_pipe,
sqs_get_queue_arn,
sqs_collect_messages,
aws_client,
snapshot,
):
Expand Down Expand Up @@ -353,7 +357,7 @@ def test_move_task_with_throughput_limit(
)
snapshot.match("start-message-move-task-response", response)
started = time.time()
messages = sqs_collect_messages(sqs, destination_queue, n, 60)
messages = sqs_collect_messages(destination_queue, n, 60)
assert {message["Body"] for message in messages} == {
"message-0",
"message-1",
Expand All @@ -378,6 +382,7 @@ def test_move_task_cancel(
sqs_create_queue,
sqs_create_dlq_pipe,
sqs_get_queue_arn,
sqs_collect_messages,
aws_client,
snapshot,
):
Expand Down Expand Up @@ -411,7 +416,7 @@ def test_move_task_cancel(
task_handle = response["TaskHandle"]

# wait for two messages to arrive, then cancel the task
messages = sqs_collect_messages(sqs, destination_queue, 2, 60)
messages = sqs_collect_messages(destination_queue, 2, 60)
assert len(messages) == 2

response = sqs.list_message_move_tasks(SourceArn=source_arn)
Expand Down
39 changes: 0 additions & 39 deletions tests/aws/services/sqs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,45 +4,6 @@

if TYPE_CHECKING:
from mypy_boto3_sqs import SQSClient
from mypy_boto3_sqs.type_defs import MessageTypeDef


def sqs_collect_messages(
sqs_client: "SQSClient",
queue_url: str,
expected: int,
timeout: int,
delete: bool = True,
attribute_names: list[str] = None,
message_attribute_names: list[str] = None,
) -> list["MessageTypeDef"]:
collected = []

def _receive():
response = sqs_client.receive_message(
QueueUrl=queue_url,
# try not to wait too long, but also not poll too often
WaitTimeSeconds=min(max(1, timeout), 5),
MaxNumberOfMessages=1,
AttributeNames=attribute_names or [],
MessageAttributeNames=message_attribute_names or [],
)

if messages := response.get("Messages"):
collected.extend(messages)

if delete:
for m in messages:
sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=m["ReceiptHandle"])

return len(collected) >= expected

if not poll_condition(_receive, timeout=timeout):
raise TimeoutError(
f"gave up waiting for messages (expected={expected}, actual={len(collected)}"
)

return collected


def get_approx_number_of_messages(
Expand Down
Loading
0