diff --git a/aws_lambda_powertools/utilities/batch/base.py b/aws_lambda_powertools/utilities/batch/base.py
index 3aea2b70fa4..4cdea6d28f5 100644
--- a/aws_lambda_powertools/utilities/batch/base.py
+++ b/aws_lambda_powertools/utilities/batch/base.py
@@ -37,6 +37,7 @@
     KinesisStreamRecord,
 )
 from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
+from aws_lambda_powertools.utilities.parser import ValidationError
 from aws_lambda_powertools.utilities.typing import LambdaContext

 logger = logging.getLogger(__name__)
@@ -316,21 +317,36 @@ def _get_messages_to_report(self) -> List[Dict[str, str]]:

     def _collect_sqs_failures(self):
         failures = []
         for msg in self.fail_messages:
-            msg_id = msg.messageId if self.model else msg.message_id
+            # If a message failed due to model validation (e.g., poison pill),
+            # it was converted to an event source data class, but self.model is still set;
+            # therefore, we do an additional check on whether the failed message is still a model
+            # see https://github.com/awslabs/aws-lambda-powertools-python/issues/2091
+            if self.model and getattr(msg, "parse_obj", None):
+                msg_id = msg.messageId
+            else:
+                msg_id = msg.message_id
             failures.append({"itemIdentifier": msg_id})
         return failures

     def _collect_kinesis_failures(self):
         failures = []
         for msg in self.fail_messages:
-            msg_id = msg.kinesis.sequenceNumber if self.model else msg.kinesis.sequence_number
+            # see https://github.com/awslabs/aws-lambda-powertools-python/issues/2091
+            if self.model and getattr(msg, "parse_obj", None):
+                msg_id = msg.kinesis.sequenceNumber
+            else:
+                msg_id = msg.kinesis.sequence_number
             failures.append({"itemIdentifier": msg_id})
         return failures

     def _collect_dynamodb_failures(self):
         failures = []
         for msg in self.fail_messages:
-            msg_id = msg.dynamodb.SequenceNumber if self.model else msg.dynamodb.sequence_number
+            # see https://github.com/awslabs/aws-lambda-powertools-python/issues/2091
+            if self.model and getattr(msg, "parse_obj", None):
+                msg_id = msg.dynamodb.SequenceNumber
+            else:
+                msg_id = msg.dynamodb.sequence_number
             failures.append({"itemIdentifier": msg_id})
         return failures

@@ -347,6 +363,17 @@ def _to_batch_type(self, record: dict, event_type: EventType, model: Optional["B
             return model.parse_obj(record)
         return self._DATA_CLASS_MAPPING[event_type](record)

+    def _register_model_validation_error_record(self, record: dict):
+        """Convert and register failure due to poison pills where model failed validation early"""
+        # Parser will fail validation if the record is a poison pill (malformed input);
+        # this means we can't collect the message id if we try transforming it again,
+        # so we convert it to the equivalent event source data class (e.g., SQS, Kinesis, DynamoDB Stream)
+        # and downstream we can collect the correct message id and make the failed record available
+        # see https://github.com/awslabs/aws-lambda-powertools-python/issues/2091
+        logger.debug("Record cannot be converted to customer's model; converting without model")
+        failed_record: "EventSourceDataClassTypes" = self._to_batch_type(record=record, event_type=self.event_type)
+        return self.failure_handler(record=failed_record, exception=sys.exc_info())
+

 class BatchProcessor(BasePartialBatchProcessor):  # Keep old name for compatibility
     """Process native partial responses from SQS, Kinesis Data Streams, and DynamoDB.
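Note: the `getattr(msg, "parse_obj", None)` checks above rely on duck typing. A record that parsed successfully but failed inside the handler sits in `fail_messages` as a Pydantic model, which exposes the `parse_obj` classmethod; a poison pill is re-registered as an event source data class, which does not. A minimal sketch of that distinction, assuming pydantic v1; `FakeModel` is a hypothetical stand-in, not part of the codebase:

from pydantic import BaseModel

from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord


class FakeModel(BaseModel):  # hypothetical model, for illustration only
    messageId: str


parsed = FakeModel(messageId="1")    # handler failed after a successful parse
raw = SQSRecord({"messageId": "2"})  # poison pill, re-wrapped without a model

for msg in (parsed, raw):
    if getattr(msg, "parse_obj", None):  # only Pydantic models expose parse_obj
        print(msg.messageId)   # model field (camelCase)
    else:
        print(msg.message_id)  # data class property (snake_case)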
""" - data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model) + data: Optional["BatchTypeModels"] = None try: + data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model) if self._handler_accepts_lambda_context: result = self.handler(record=data, lambda_context=self.lambda_context) else: result = self.handler(record=data) return self.success_handler(record=record, result=result) + except ValidationError: + return self._register_model_validation_error_record(record) except Exception: return self.failure_handler(record=data, exception=sys.exc_info()) @@ -651,14 +681,17 @@ async def _async_process_record(self, record: dict) -> Union[SuccessResponse, Fa record: dict A batch record to be processed. """ - data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model) + data: Optional["BatchTypeModels"] = None try: + data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model) if self._handler_accepts_lambda_context: result = await self.handler(record=data, lambda_context=self.lambda_context) else: result = await self.handler(record=data) return self.success_handler(record=record, result=result) + except ValidationError: + return self._register_model_validation_error_record(record) except Exception: return self.failure_handler(record=data, exception=sys.exc_info()) diff --git a/aws_lambda_powertools/utilities/parser/models/dynamodb.py b/aws_lambda_powertools/utilities/parser/models/dynamodb.py index 772b8fb580f..4c85c72d438 100644 --- a/aws_lambda_powertools/utilities/parser/models/dynamodb.py +++ b/aws_lambda_powertools/utilities/parser/models/dynamodb.py @@ -9,8 +9,8 @@ class DynamoDBStreamChangedRecordModel(BaseModel): ApproximateCreationDateTime: Optional[date] Keys: Dict[str, Dict[str, Any]] - NewImage: Optional[Union[Dict[str, Any], Type[BaseModel]]] - OldImage: Optional[Union[Dict[str, Any], Type[BaseModel]]] + NewImage: Optional[Union[Dict[str, Any], Type[BaseModel], BaseModel]] + OldImage: Optional[Union[Dict[str, Any], Type[BaseModel], BaseModel]] SequenceNumber: str SizeBytes: int StreamViewType: Literal["NEW_AND_OLD_IMAGES", "KEYS_ONLY", "NEW_IMAGE", "OLD_IMAGE"] diff --git a/aws_lambda_powertools/utilities/parser/models/kinesis.py b/aws_lambda_powertools/utilities/parser/models/kinesis.py index 6fb9a7076b5..bb6d6b5318f 100644 --- a/aws_lambda_powertools/utilities/parser/models/kinesis.py +++ b/aws_lambda_powertools/utilities/parser/models/kinesis.py @@ -15,7 +15,7 @@ class KinesisDataStreamRecordPayload(BaseModel): kinesisSchemaVersion: str partitionKey: str sequenceNumber: str - data: Union[bytes, Type[BaseModel]] # base64 encoded str is parsed into bytes + data: Union[bytes, Type[BaseModel], BaseModel] # base64 encoded str is parsed into bytes approximateArrivalTimestamp: float @validator("data", pre=True, allow_reuse=True) diff --git a/aws_lambda_powertools/utilities/parser/models/sqs.py b/aws_lambda_powertools/utilities/parser/models/sqs.py index 1d56c4f8e34..c92a8361b7c 100644 --- a/aws_lambda_powertools/utilities/parser/models/sqs.py +++ b/aws_lambda_powertools/utilities/parser/models/sqs.py @@ -52,7 +52,7 @@ class SqsMsgAttributeModel(BaseModel): class SqsRecordModel(BaseModel): messageId: str receiptHandle: str - body: Union[str, Type[BaseModel]] + body: Union[str, Type[BaseModel], BaseModel] attributes: SqsAttributesModel messageAttributes: Dict[str, SqsMsgAttributeModel] md5OfBody: str diff --git a/aws_lambda_powertools/utilities/parser/types.py 
diff --git a/aws_lambda_powertools/utilities/parser/types.py b/aws_lambda_powertools/utilities/parser/types.py
index e9acceb8963..d3f00646d52 100644
--- a/aws_lambda_powertools/utilities/parser/types.py
+++ b/aws_lambda_powertools/utilities/parser/types.py
@@ -3,16 +3,18 @@
 import sys
 from typing import Any, Dict, Type, TypeVar, Union

-from pydantic import BaseModel
+from pydantic import BaseModel, Json

 # We only need typing_extensions for python versions <3.8
 if sys.version_info >= (3, 8):
-    from typing import Literal  # noqa: F401
+    from typing import Literal
 else:
-    from typing_extensions import Literal  # noqa: F401
+    from typing_extensions import Literal

 Model = TypeVar("Model", bound=BaseModel)
 EnvelopeModel = TypeVar("EnvelopeModel")
 EventParserReturnType = TypeVar("EventParserReturnType")
 AnyInheritedModel = Union[Type[BaseModel], BaseModel]
 RawDictOrModel = Union[Dict[str, Any], AnyInheritedModel]
+
+__all__ = ["Json", "Literal"]
diff --git a/tests/functional/batch/__init__.py b/tests/functional/batch/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/tests/functional/batch/sample_models.py b/tests/functional/batch/sample_models.py
new file mode 100644
index 00000000000..556ff0ebf8a
--- /dev/null
+++ b/tests/functional/batch/sample_models.py
@@ -0,0 +1,47 @@
+import json
+from typing import Dict, Optional
+
+from aws_lambda_powertools.utilities.parser import BaseModel, validator
+from aws_lambda_powertools.utilities.parser.models import (
+    DynamoDBStreamChangedRecordModel,
+    DynamoDBStreamRecordModel,
+    KinesisDataStreamRecord,
+    KinesisDataStreamRecordPayload,
+    SqsRecordModel,
+)
+from aws_lambda_powertools.utilities.parser.types import Json, Literal
+
+
+class Order(BaseModel):
+    item: dict
+
+
+class OrderSqs(SqsRecordModel):
+    body: Json[Order]
+
+
+class OrderKinesisPayloadRecord(KinesisDataStreamRecordPayload):
+    data: Json[Order]
+
+
+class OrderKinesisRecord(KinesisDataStreamRecord):
+    kinesis: OrderKinesisPayloadRecord
+
+
+class OrderDynamoDB(BaseModel):
+    Message: Order
+
+    # auto transform json string
+    # so Pydantic can auto-initialize nested Order model
+    @validator("Message", pre=True)
+    def transform_message_to_dict(cls, value: Dict[Literal["S"], str]):
+        return json.loads(value["S"])
+
+
+class OrderDynamoDBChangeRecord(DynamoDBStreamChangedRecordModel):
+    NewImage: Optional[OrderDynamoDB]
+    OldImage: Optional[OrderDynamoDB]
+
+
+class OrderDynamoDBRecord(DynamoDBStreamRecordModel):
+    dynamodb: OrderDynamoDBChangeRecord
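In the new sample models, `Json[Order]` replaces the hand-written `@validator(..., pre=True)` plus `json.loads` pairs the tests used to declare inline: pydantic decodes the JSON string first and then validates the nested model. A minimal usage sketch, assuming pydantic v1; `Envelope` is a hypothetical stand-in for the `SqsRecordModel` subclass:

from pydantic import BaseModel, Json


class Order(BaseModel):
    item: dict


class Envelope(BaseModel):  # stand-in for OrderSqs / OrderKinesisPayloadRecord
    body: Json[Order]


env = Envelope(body='{"item": {"type": "success"}}')
assert env.body.item == {"type": "success"}  # string decoded, then validated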
diff --git a/tests/functional/test_utilities_batch.py b/tests/functional/test_utilities_batch.py
index c98d59a7042..2205d47660c 100644
--- a/tests/functional/test_utilities_batch.py
+++ b/tests/functional/test_utilities_batch.py
@@ -27,14 +27,12 @@
     DynamoDBStreamChangedRecordModel,
     DynamoDBStreamRecordModel,
 )
-from aws_lambda_powertools.utilities.parser.models import (
-    KinesisDataStreamRecord as KinesisDataStreamRecordModel,
-)
-from aws_lambda_powertools.utilities.parser.models import (
-    KinesisDataStreamRecordPayload,
-    SqsRecordModel,
-)
 from aws_lambda_powertools.utilities.parser.types import Literal
+from tests.functional.batch.sample_models import (
+    OrderDynamoDBRecord,
+    OrderKinesisRecord,
+    OrderSqs,
+)
 from tests.functional.utils import b64_to_str, str_to_b64

@@ -119,6 +117,16 @@ def handler(record):
     return handler


+@pytest.fixture(scope="module")
+def record_handler_model() -> Callable:
+    def record_handler(record: OrderSqs):
+        if "fail" in record.body.item["type"]:
+            raise Exception("Failed to process record.")
+        return record.body.item
+
+    return record_handler
+
+
 @pytest.fixture(scope="module")
 def async_record_handler() -> Callable[..., Awaitable[Any]]:
     async def handler(record):
@@ -130,6 +138,16 @@ async def handler(record):
     return handler


+@pytest.fixture(scope="module")
+def async_record_handler_model() -> Callable[..., Awaitable[Any]]:
+    async def async_record_handler(record: OrderSqs):
+        if "fail" in record.body.item["type"]:
+            raise ValueError("Failed to process record.")
+        return record.body.item
+
+    return async_record_handler
+
+
 @pytest.fixture(scope="module")
 def kinesis_record_handler() -> Callable:
     def handler(record: KinesisStreamRecord):
@@ -141,17 +159,57 @@ def handler(record: KinesisStreamRecord):
     return handler


+@pytest.fixture(scope="module")
+def kinesis_record_handler_model() -> Callable:
+    def record_handler(record: OrderKinesisRecord):
+        if "fail" in record.kinesis.data.item["type"]:
+            raise ValueError("Failed to process record.")
+        return record.kinesis.data.item
+
+    return record_handler
+
+
+@pytest.fixture(scope="module")
+def async_kinesis_record_handler_model() -> Callable[..., Awaitable[Any]]:
+    async def record_handler(record: OrderKinesisRecord):
+        if "fail" in record.kinesis.data.item["type"]:
+            raise Exception("Failed to process record.")
+        return record.kinesis.data.item
+
+    return record_handler
+
+
 @pytest.fixture(scope="module")
 def dynamodb_record_handler() -> Callable:
     def handler(record: DynamoDBRecord):
         body = record.dynamodb.new_image.get("Message")
         if "fail" in body:
-            raise Exception("Failed to process record.")
+            raise ValueError("Failed to process record.")
         return body

     return handler


+@pytest.fixture(scope="module")
+def dynamodb_record_handler_model() -> Callable:
+    def record_handler(record: OrderDynamoDBRecord):
+        if "fail" in record.dynamodb.NewImage.Message.item["type"]:
+            raise ValueError("Failed to process record.")
+        return record.dynamodb.NewImage.Message.item
+
+    return record_handler
+
+
+@pytest.fixture(scope="module")
+def async_dynamodb_record_handler() -> Callable[..., Awaitable[Any]]:
+    async def record_handler(record: OrderDynamoDBRecord):
+        if "fail" in record.dynamodb.NewImage.Message.item["type"]:
+            raise ValueError("Failed to process record.")
+        return record.dynamodb.NewImage.Message.item
+
+    return record_handler
+
+
 @pytest.fixture(scope="module")
 def config() -> Config:
     return Config(region_name="us-east-1")
@@ -374,18 +432,6 @@ def lambda_handler(event, context):

 def test_batch_processor_context_model(sqs_event_factory, order_event_factory):
     # GIVEN
-    class Order(BaseModel):
-        item: dict
-
-    class OrderSqs(SqsRecordModel):
-        body: Order
-
-        # auto transform json string
-        # so Pydantic can auto-initialize nested Order model
-        @validator("body", pre=True)
-        def transform_body_to_dict(cls, value: str):
-            return json.loads(value)
-
     def record_handler(record: OrderSqs):
         return record.body.item

@@ -411,18 +457,6 @@ def record_handler(record: OrderSqs):

 def test_batch_processor_context_model_with_failure(sqs_event_factory, order_event_factory):
     # GIVEN
-    class Order(BaseModel):
-        item: dict
-
-    class OrderSqs(SqsRecordModel):
-        body: Order
-
-        # auto transform json string
-        # so Pydantic can auto-initialize nested Order model
-        @validator("body", pre=True)
-        def transform_body_to_dict(cls, value: str):
-            return json.loads(value)
-
     def record_handler(record: OrderSqs):
         if "fail" in record.body.item["type"]:
             raise Exception("Failed to process record.")
@@ -542,27 +576,10 @@ def record_handler(record: OrderDynamoDBRecord):
     }


-def test_batch_processor_kinesis_context_parser_model(kinesis_event_factory, order_event_factory):
+def test_batch_processor_kinesis_context_parser_model(
+    kinesis_record_handler_model: Callable, kinesis_event_factory, order_event_factory
+):
     # GIVEN
-    class Order(BaseModel):
-        item: dict
-
-    class OrderKinesisPayloadRecord(KinesisDataStreamRecordPayload):
-        data: Order
-
-        # auto transform json string
-        # so Pydantic can auto-initialize nested Order model
-        @validator("data", pre=True)
-        def transform_message_to_dict(cls, value: str):
-            # Powertools KinesisDataStreamRecordModel already decodes b64 to str here
-            return json.loads(value)
-
-    class OrderKinesisRecord(KinesisDataStreamRecordModel):
-        kinesis: OrderKinesisPayloadRecord
-
-    def record_handler(record: OrderKinesisRecord):
-        return record.kinesis.data.item
-
     order_event = order_event_factory({"type": "success"})
     first_record = kinesis_event_factory(order_event)
     second_record = kinesis_event_factory(order_event)
@@ -570,7 +587,7 @@ def record_handler(record: OrderKinesisRecord):
     # WHEN
     processor = BatchProcessor(event_type=EventType.KinesisDataStreams, model=OrderKinesisRecord)
-    with processor(records, record_handler) as batch:
+    with processor(records, kinesis_record_handler_model) as batch:
         processed_messages = batch.process()

     # THEN
@@ -583,29 +600,10 @@
     assert batch.response() == {"batchItemFailures": []}


-def test_batch_processor_kinesis_context_parser_model_with_failure(kinesis_event_factory, order_event_factory):
+def test_batch_processor_kinesis_context_parser_model_with_failure(
+    kinesis_record_handler_model: Callable, kinesis_event_factory, order_event_factory
+):
     # GIVEN
-    class Order(BaseModel):
-        item: dict
-
-    class OrderKinesisPayloadRecord(KinesisDataStreamRecordPayload):
-        data: Order
-
-        # auto transform json string
-        # so Pydantic can auto-initialize nested Order model
-        @validator("data", pre=True)
-        def transform_message_to_dict(cls, value: str):
-            # Powertools KinesisDataStreamRecordModel
-            return json.loads(value)
-
-    class OrderKinesisRecord(KinesisDataStreamRecordModel):
-        kinesis: OrderKinesisPayloadRecord
-
-    def record_handler(record: OrderKinesisRecord):
-        if "fail" in record.kinesis.data.item["type"]:
-            raise Exception("Failed to process record.")
-        return record.kinesis.data.item
-
     order_event = order_event_factory({"type": "success"})
     order_event_fail = order_event_factory({"type": "fail"})
@@ -616,7 +614,7 @@

     # WHEN
     processor = BatchProcessor(event_type=EventType.KinesisDataStreams, model=OrderKinesisRecord)
-    with processor(records, record_handler) as batch:
+    with processor(records, kinesis_record_handler_model) as batch:
         batch.process()

     # THEN
@@ -775,3 +773,147 @@ def test_async_batch_processor_context_with_failure(sqs_event_factory, async_rec
     assert batch.response() == {
         "batchItemFailures": [{"itemIdentifier": first_record.message_id}, {"itemIdentifier": third_record.message_id}]
     }
+
+
+def test_batch_processor_model_with_partial_validation_error(
+    record_handler_model: Callable, sqs_event_factory, order_event_factory
+):
+    # GIVEN
+    order_event = order_event_factory({"type": "success"})
+    first_record = sqs_event_factory(order_event)
+    second_record = sqs_event_factory(order_event)
+    malformed_record = sqs_event_factory({"poison": "pill"})
+    records = [first_record, malformed_record, second_record]
+
+    # WHEN
+    processor = BatchProcessor(event_type=EventType.SQS, model=OrderSqs)
+    with processor(records, record_handler_model) as batch:
+        batch.process()
+
+    # THEN
+    assert len(batch.fail_messages) == 1
+    assert batch.response() == {
+        "batchItemFailures": [
+            {"itemIdentifier": malformed_record["messageId"]},
+        ]
+    }
+
+
+def test_batch_processor_dynamodb_context_model_with_partial_validation_error(
+    dynamodb_record_handler_model: Callable, dynamodb_event_factory, order_event_factory
+):
+    # GIVEN
+    order_event = order_event_factory({"type": "success"})
+    first_record = dynamodb_event_factory(order_event)
+    second_record = dynamodb_event_factory(order_event)
+    malformed_record = dynamodb_event_factory({"poison": "pill"})
+    records = [first_record, malformed_record, second_record]
+
+    # WHEN
+    processor = BatchProcessor(event_type=EventType.DynamoDBStreams, model=OrderDynamoDBRecord)
+    with processor(records, dynamodb_record_handler_model) as batch:
+        batch.process()
+
+    # THEN
+    assert len(batch.fail_messages) == 1
+    assert batch.response() == {
+        "batchItemFailures": [
+            {"itemIdentifier": malformed_record["dynamodb"]["SequenceNumber"]},
+        ]
+    }
+
+
+def test_batch_processor_kinesis_context_parser_model_with_partial_validation_error(
+    kinesis_record_handler_model: Callable, kinesis_event_factory, order_event_factory
+):
+    # GIVEN
+    order_event = order_event_factory({"type": "success"})
+    first_record = kinesis_event_factory(order_event)
+    second_record = kinesis_event_factory(order_event)
+    malformed_record = kinesis_event_factory('{"poison": "pill"}')
+    records = [first_record, malformed_record, second_record]
+
+    # WHEN
+    processor = BatchProcessor(event_type=EventType.KinesisDataStreams, model=OrderKinesisRecord)
+    with processor(records, kinesis_record_handler_model) as batch:
+        batch.process()
+
+    # THEN
+    assert len(batch.fail_messages) == 1
+    assert batch.response() == {
+        "batchItemFailures": [
+            {"itemIdentifier": malformed_record["kinesis"]["sequenceNumber"]},
+        ]
+    }
+
+
+def test_async_batch_processor_model_with_partial_validation_error(
+    async_record_handler_model: Callable, sqs_event_factory, order_event_factory
+):
+    # GIVEN
+    order_event = order_event_factory({"type": "success"})
+    first_record = sqs_event_factory(order_event)
+    second_record = sqs_event_factory(order_event)
+    malformed_record = sqs_event_factory({"poison": "pill"})
+    records = [first_record, malformed_record, second_record]
+
+    # WHEN
+    processor = AsyncBatchProcessor(event_type=EventType.SQS, model=OrderSqs)
+    with processor(records, async_record_handler_model) as batch:
+        batch.async_process()
+
+    # THEN
+    assert len(batch.fail_messages) == 1
+    assert batch.response() == {
+        "batchItemFailures": [
+            {"itemIdentifier": malformed_record["messageId"]},
+        ]
+    }
+
+
+def test_async_batch_processor_dynamodb_context_model_with_partial_validation_error(
+    async_dynamodb_record_handler: Callable, dynamodb_event_factory, order_event_factory
+):
+    # GIVEN
+    order_event = order_event_factory({"type": "success"})
+    first_record = dynamodb_event_factory(order_event)
+    second_record = dynamodb_event_factory(order_event)
+    malformed_record = dynamodb_event_factory({"poison": "pill"})
+    records = [first_record, malformed_record, second_record]
+
+    # WHEN
+    processor = AsyncBatchProcessor(event_type=EventType.DynamoDBStreams, model=OrderDynamoDBRecord)
+    with processor(records, async_dynamodb_record_handler) as batch:
+        batch.async_process()
+
+    # THEN
+    assert len(batch.fail_messages) == 1
+    assert batch.response() == {
+        "batchItemFailures": [
malformed_record["dynamodb"]["SequenceNumber"]}, + ] + } + + +def test_async_batch_processor_kinesis_context_parser_model_with_partial_validation_error( + async_kinesis_record_handler_model: Callable, kinesis_event_factory, order_event_factory +): + # GIVEN + order_event = order_event_factory({"type": "success"}) + first_record = kinesis_event_factory(order_event) + second_record = kinesis_event_factory(order_event) + malformed_record = kinesis_event_factory('{"poison": "pill"}') + records = [first_record, malformed_record, second_record] + + # WHEN + processor = AsyncBatchProcessor(event_type=EventType.KinesisDataStreams, model=OrderKinesisRecord) + with processor(records, async_kinesis_record_handler_model) as batch: + batch.async_process() + + # THEN + assert len(batch.fail_messages) == 1 + assert batch.response() == { + "batchItemFailures": [ + {"itemIdentifier": malformed_record["kinesis"]["sequenceNumber"]}, + ] + }