From d254cd2827cdb6d571523508d4d0df938b9d6ce2 Mon Sep 17 00:00:00 2001 From: Alexander Rashed Date: Tue, 24 May 2022 11:22:36 +0200 Subject: [PATCH 01/16] initial prototype --- localstack/aws/api/kinesis/__init__.py | 845 ++++++++++++++++++ .../services/kinesis/kinesis_starter.py | 17 +- localstack/services/kinesis/provider.py | 374 ++++++++ localstack/services/providers.py | 16 +- 4 files changed, 1244 insertions(+), 8 deletions(-) create mode 100644 localstack/aws/api/kinesis/__init__.py create mode 100644 localstack/services/kinesis/provider.py diff --git a/localstack/aws/api/kinesis/__init__.py b/localstack/aws/api/kinesis/__init__.py new file mode 100644 index 0000000000000..75ad27d107783 --- /dev/null +++ b/localstack/aws/api/kinesis/__init__.py @@ -0,0 +1,845 @@ +import sys +from datetime import datetime +from typing import Dict, Iterator, List, Optional + +if sys.version_info >= (3, 8): + from typing import TypedDict +else: + from typing_extensions import TypedDict + +from localstack.aws.api import RequestContext, ServiceException, ServiceRequest, handler + +BooleanObject = bool +ConsumerARN = str +ConsumerCountObject = int +ConsumerName = str +DescribeStreamInputLimit = int +ErrorCode = str +ErrorMessage = str +GetRecordsInputLimit = int +HashKey = str +KeyId = str +ListShardsInputLimit = int +ListStreamConsumersInputLimit = int +ListStreamsInputLimit = int +ListTagsForStreamInputLimit = int +NextToken = str +OnDemandStreamCountLimitObject = int +OnDemandStreamCountObject = int +PartitionKey = str +PositiveIntegerObject = int +RetentionPeriodHours = int +SequenceNumber = str +ShardCountObject = int +ShardId = str +ShardIterator = str +StreamARN = str +StreamName = str +TagKey = str +TagValue = str + + +class ConsumerStatus(str): + CREATING = "CREATING" + DELETING = "DELETING" + ACTIVE = "ACTIVE" + + +class EncryptionType(str): + NONE = "NONE" + KMS = "KMS" + + +class MetricsName(str): + IncomingBytes = "IncomingBytes" + IncomingRecords = "IncomingRecords" + OutgoingBytes = "OutgoingBytes" + OutgoingRecords = "OutgoingRecords" + WriteProvisionedThroughputExceeded = "WriteProvisionedThroughputExceeded" + ReadProvisionedThroughputExceeded = "ReadProvisionedThroughputExceeded" + IteratorAgeMilliseconds = "IteratorAgeMilliseconds" + ALL = "ALL" + + +class ScalingType(str): + UNIFORM_SCALING = "UNIFORM_SCALING" + + +class ShardFilterType(str): + AFTER_SHARD_ID = "AFTER_SHARD_ID" + AT_TRIM_HORIZON = "AT_TRIM_HORIZON" + FROM_TRIM_HORIZON = "FROM_TRIM_HORIZON" + AT_LATEST = "AT_LATEST" + AT_TIMESTAMP = "AT_TIMESTAMP" + FROM_TIMESTAMP = "FROM_TIMESTAMP" + + +class ShardIteratorType(str): + AT_SEQUENCE_NUMBER = "AT_SEQUENCE_NUMBER" + AFTER_SEQUENCE_NUMBER = "AFTER_SEQUENCE_NUMBER" + TRIM_HORIZON = "TRIM_HORIZON" + LATEST = "LATEST" + AT_TIMESTAMP = "AT_TIMESTAMP" + + +class StreamMode(str): + PROVISIONED = "PROVISIONED" + ON_DEMAND = "ON_DEMAND" + + +class StreamStatus(str): + CREATING = "CREATING" + DELETING = "DELETING" + ACTIVE = "ACTIVE" + UPDATING = "UPDATING" + + +class ExpiredIteratorException(ServiceException): + message: Optional[ErrorMessage] + + +class ExpiredNextTokenException(ServiceException): + message: Optional[ErrorMessage] + + +class InternalFailureException(ServiceException): + message: Optional[ErrorMessage] + + +class InvalidArgumentException(ServiceException): + message: Optional[ErrorMessage] + + +class KMSAccessDeniedException(ServiceException): + message: Optional[ErrorMessage] + + +class KMSDisabledException(ServiceException): + message: Optional[ErrorMessage] + + +class KMSInvalidStateException(ServiceException): + message: Optional[ErrorMessage] + + +class KMSNotFoundException(ServiceException): + message: Optional[ErrorMessage] + + +class KMSOptInRequired(ServiceException): + message: Optional[ErrorMessage] + + +class KMSThrottlingException(ServiceException): + message: Optional[ErrorMessage] + + +class LimitExceededException(ServiceException): + message: Optional[ErrorMessage] + + +class ProvisionedThroughputExceededException(ServiceException): + message: Optional[ErrorMessage] + + +class ResourceInUseException(ServiceException): + message: Optional[ErrorMessage] + + +class ResourceNotFoundException(ServiceException): + message: Optional[ErrorMessage] + + +class ValidationException(ServiceException): + message: Optional[ErrorMessage] + + +TagMap = Dict[TagKey, TagValue] + + +class AddTagsToStreamInput(ServiceRequest): + StreamName: StreamName + Tags: TagMap + + +class HashKeyRange(TypedDict, total=False): + StartingHashKey: HashKey + EndingHashKey: HashKey + + +ShardIdList = List[ShardId] + + +class ChildShard(TypedDict, total=False): + ShardId: ShardId + ParentShards: ShardIdList + HashKeyRange: HashKeyRange + + +ChildShardList = List[ChildShard] +Timestamp = datetime + + +class Consumer(TypedDict, total=False): + ConsumerName: ConsumerName + ConsumerARN: ConsumerARN + ConsumerStatus: ConsumerStatus + ConsumerCreationTimestamp: Timestamp + + +class ConsumerDescription(TypedDict, total=False): + ConsumerName: ConsumerName + ConsumerARN: ConsumerARN + ConsumerStatus: ConsumerStatus + ConsumerCreationTimestamp: Timestamp + StreamARN: StreamARN + + +ConsumerList = List[Consumer] + + +class StreamModeDetails(TypedDict, total=False): + StreamMode: StreamMode + + +class CreateStreamInput(ServiceRequest): + StreamName: StreamName + ShardCount: Optional[PositiveIntegerObject] + StreamModeDetails: Optional[StreamModeDetails] + + +Data = bytes + + +class DecreaseStreamRetentionPeriodInput(ServiceRequest): + StreamName: StreamName + RetentionPeriodHours: RetentionPeriodHours + + +class DeleteStreamInput(ServiceRequest): + StreamName: StreamName + EnforceConsumerDeletion: Optional[BooleanObject] + + +class DeregisterStreamConsumerInput(ServiceRequest): + StreamARN: Optional[StreamARN] + ConsumerName: Optional[ConsumerName] + ConsumerARN: Optional[ConsumerARN] + + +class DescribeLimitsInput(ServiceRequest): + pass + + +class DescribeLimitsOutput(TypedDict, total=False): + ShardLimit: ShardCountObject + OpenShardCount: ShardCountObject + OnDemandStreamCount: OnDemandStreamCountObject + OnDemandStreamCountLimit: OnDemandStreamCountLimitObject + + +class DescribeStreamConsumerInput(ServiceRequest): + StreamARN: Optional[StreamARN] + ConsumerName: Optional[ConsumerName] + ConsumerARN: Optional[ConsumerARN] + + +class DescribeStreamConsumerOutput(TypedDict, total=False): + ConsumerDescription: ConsumerDescription + + +class DescribeStreamInput(ServiceRequest): + StreamName: StreamName + Limit: Optional[DescribeStreamInputLimit] + ExclusiveStartShardId: Optional[ShardId] + + +MetricsNameList = List[MetricsName] + + +class EnhancedMetrics(TypedDict, total=False): + ShardLevelMetrics: Optional[MetricsNameList] + + +EnhancedMonitoringList = List[EnhancedMetrics] + + +class SequenceNumberRange(TypedDict, total=False): + StartingSequenceNumber: SequenceNumber + EndingSequenceNumber: Optional[SequenceNumber] + + +class Shard(TypedDict, total=False): + ShardId: ShardId + ParentShardId: Optional[ShardId] + AdjacentParentShardId: Optional[ShardId] + HashKeyRange: HashKeyRange + SequenceNumberRange: SequenceNumberRange + + +ShardList = List[Shard] + + +class StreamDescription(TypedDict, total=False): + StreamName: StreamName + StreamARN: StreamARN + StreamStatus: StreamStatus + StreamModeDetails: Optional[StreamModeDetails] + Shards: ShardList + HasMoreShards: BooleanObject + RetentionPeriodHours: RetentionPeriodHours + StreamCreationTimestamp: Timestamp + EnhancedMonitoring: EnhancedMonitoringList + EncryptionType: Optional[EncryptionType] + KeyId: Optional[KeyId] + + +class DescribeStreamOutput(TypedDict, total=False): + StreamDescription: StreamDescription + + +class DescribeStreamSummaryInput(ServiceRequest): + StreamName: StreamName + + +class StreamDescriptionSummary(TypedDict, total=False): + StreamName: StreamName + StreamARN: StreamARN + StreamStatus: StreamStatus + StreamModeDetails: Optional[StreamModeDetails] + RetentionPeriodHours: RetentionPeriodHours + StreamCreationTimestamp: Timestamp + EnhancedMonitoring: EnhancedMonitoringList + EncryptionType: Optional[EncryptionType] + KeyId: Optional[KeyId] + OpenShardCount: ShardCountObject + ConsumerCount: Optional[ConsumerCountObject] + + +class DescribeStreamSummaryOutput(TypedDict, total=False): + StreamDescriptionSummary: StreamDescriptionSummary + + +class DisableEnhancedMonitoringInput(ServiceRequest): + StreamName: StreamName + ShardLevelMetrics: MetricsNameList + + +class EnableEnhancedMonitoringInput(ServiceRequest): + StreamName: StreamName + ShardLevelMetrics: MetricsNameList + + +class EnhancedMonitoringOutput(TypedDict, total=False): + StreamName: Optional[StreamName] + CurrentShardLevelMetrics: Optional[MetricsNameList] + DesiredShardLevelMetrics: Optional[MetricsNameList] + + +class GetRecordsInput(ServiceRequest): + ShardIterator: ShardIterator + Limit: Optional[GetRecordsInputLimit] + + +MillisBehindLatest = int + + +class Record(TypedDict, total=False): + SequenceNumber: SequenceNumber + ApproximateArrivalTimestamp: Optional[Timestamp] + Data: Data + PartitionKey: PartitionKey + EncryptionType: Optional[EncryptionType] + + +RecordList = List[Record] + + +class GetRecordsOutput(TypedDict, total=False): + Records: RecordList + NextShardIterator: Optional[ShardIterator] + MillisBehindLatest: Optional[MillisBehindLatest] + ChildShards: Optional[ChildShardList] + + +class GetShardIteratorInput(ServiceRequest): + StreamName: StreamName + ShardId: ShardId + ShardIteratorType: ShardIteratorType + StartingSequenceNumber: Optional[SequenceNumber] + Timestamp: Optional[Timestamp] + + +class GetShardIteratorOutput(TypedDict, total=False): + ShardIterator: Optional[ShardIterator] + + +class IncreaseStreamRetentionPeriodInput(ServiceRequest): + StreamName: StreamName + RetentionPeriodHours: RetentionPeriodHours + + +class ShardFilter(TypedDict, total=False): + Type: ShardFilterType + ShardId: Optional[ShardId] + Timestamp: Optional[Timestamp] + + +class ListShardsInput(ServiceRequest): + StreamName: Optional[StreamName] + NextToken: Optional[NextToken] + ExclusiveStartShardId: Optional[ShardId] + MaxResults: Optional[ListShardsInputLimit] + StreamCreationTimestamp: Optional[Timestamp] + ShardFilter: Optional[ShardFilter] + + +class ListShardsOutput(TypedDict, total=False): + Shards: Optional[ShardList] + NextToken: Optional[NextToken] + + +class ListStreamConsumersInput(ServiceRequest): + StreamARN: StreamARN + NextToken: Optional[NextToken] + MaxResults: Optional[ListStreamConsumersInputLimit] + StreamCreationTimestamp: Optional[Timestamp] + + +class ListStreamConsumersOutput(TypedDict, total=False): + Consumers: Optional[ConsumerList] + NextToken: Optional[NextToken] + + +class ListStreamsInput(ServiceRequest): + Limit: Optional[ListStreamsInputLimit] + ExclusiveStartStreamName: Optional[StreamName] + + +StreamNameList = List[StreamName] + + +class ListStreamsOutput(TypedDict, total=False): + StreamNames: StreamNameList + HasMoreStreams: BooleanObject + + +class ListTagsForStreamInput(ServiceRequest): + StreamName: StreamName + ExclusiveStartTagKey: Optional[TagKey] + Limit: Optional[ListTagsForStreamInputLimit] + + +class Tag(TypedDict, total=False): + Key: TagKey + Value: Optional[TagValue] + + +TagList = List[Tag] + + +class ListTagsForStreamOutput(TypedDict, total=False): + Tags: TagList + HasMoreTags: BooleanObject + + +class MergeShardsInput(ServiceRequest): + StreamName: StreamName + ShardToMerge: ShardId + AdjacentShardToMerge: ShardId + + +class PutRecordInput(ServiceRequest): + StreamName: StreamName + Data: Data + PartitionKey: PartitionKey + ExplicitHashKey: Optional[HashKey] + SequenceNumberForOrdering: Optional[SequenceNumber] + + +class PutRecordOutput(TypedDict, total=False): + ShardId: ShardId + SequenceNumber: SequenceNumber + EncryptionType: Optional[EncryptionType] + + +class PutRecordsRequestEntry(TypedDict, total=False): + Data: Data + ExplicitHashKey: Optional[HashKey] + PartitionKey: PartitionKey + + +PutRecordsRequestEntryList = List[PutRecordsRequestEntry] + + +class PutRecordsInput(ServiceRequest): + Records: PutRecordsRequestEntryList + StreamName: StreamName + + +class PutRecordsResultEntry(TypedDict, total=False): + SequenceNumber: Optional[SequenceNumber] + ShardId: Optional[ShardId] + ErrorCode: Optional[ErrorCode] + ErrorMessage: Optional[ErrorMessage] + + +PutRecordsResultEntryList = List[PutRecordsResultEntry] + + +class PutRecordsOutput(TypedDict, total=False): + FailedRecordCount: Optional[PositiveIntegerObject] + Records: PutRecordsResultEntryList + EncryptionType: Optional[EncryptionType] + + +class RegisterStreamConsumerInput(ServiceRequest): + StreamARN: StreamARN + ConsumerName: ConsumerName + + +class RegisterStreamConsumerOutput(TypedDict, total=False): + Consumer: Consumer + + +TagKeyList = List[TagKey] + + +class RemoveTagsFromStreamInput(ServiceRequest): + StreamName: StreamName + TagKeys: TagKeyList + + +class SplitShardInput(ServiceRequest): + StreamName: StreamName + ShardToSplit: ShardId + NewStartingHashKey: HashKey + + +class StartStreamEncryptionInput(ServiceRequest): + StreamName: StreamName + EncryptionType: EncryptionType + KeyId: KeyId + + +class StartingPosition(TypedDict, total=False): + Type: ShardIteratorType + SequenceNumber: Optional[SequenceNumber] + Timestamp: Optional[Timestamp] + + +class StopStreamEncryptionInput(ServiceRequest): + StreamName: StreamName + EncryptionType: EncryptionType + KeyId: KeyId + + +class SubscribeToShardEvent(TypedDict, total=False): + Records: RecordList + ContinuationSequenceNumber: SequenceNumber + MillisBehindLatest: MillisBehindLatest + ChildShards: Optional[ChildShardList] + + +class SubscribeToShardEventStream(TypedDict, total=False): + SubscribeToShardEvent: SubscribeToShardEvent + ResourceNotFoundException: Optional[ResourceNotFoundException] + ResourceInUseException: Optional[ResourceInUseException] + KMSDisabledException: Optional[KMSDisabledException] + KMSInvalidStateException: Optional[KMSInvalidStateException] + KMSAccessDeniedException: Optional[KMSAccessDeniedException] + KMSNotFoundException: Optional[KMSNotFoundException] + KMSOptInRequired: Optional[KMSOptInRequired] + KMSThrottlingException: Optional[KMSThrottlingException] + InternalFailureException: Optional[InternalFailureException] + + +class SubscribeToShardInput(ServiceRequest): + ConsumerARN: ConsumerARN + ShardId: ShardId + StartingPosition: StartingPosition + + +class SubscribeToShardOutput(TypedDict, total=False): + EventStream: Iterator[SubscribeToShardEventStream] + + +class UpdateShardCountInput(ServiceRequest): + StreamName: StreamName + TargetShardCount: PositiveIntegerObject + ScalingType: ScalingType + + +class UpdateShardCountOutput(TypedDict, total=False): + StreamName: Optional[StreamName] + CurrentShardCount: Optional[PositiveIntegerObject] + TargetShardCount: Optional[PositiveIntegerObject] + + +class UpdateStreamModeInput(ServiceRequest): + StreamARN: StreamARN + StreamModeDetails: StreamModeDetails + + +class KinesisApi: + + service = "kinesis" + version = "2013-12-02" + + @handler("AddTagsToStream") + def add_tags_to_stream( + self, context: RequestContext, stream_name: StreamName, tags: TagMap + ) -> None: + raise NotImplementedError + + @handler("CreateStream") + def create_stream( + self, + context: RequestContext, + stream_name: StreamName, + shard_count: PositiveIntegerObject = None, + stream_mode_details: StreamModeDetails = None, + ) -> None: + raise NotImplementedError + + @handler("DecreaseStreamRetentionPeriod") + def decrease_stream_retention_period( + self, + context: RequestContext, + stream_name: StreamName, + retention_period_hours: RetentionPeriodHours, + ) -> None: + raise NotImplementedError + + @handler("DeleteStream") + def delete_stream( + self, + context: RequestContext, + stream_name: StreamName, + enforce_consumer_deletion: BooleanObject = None, + ) -> None: + raise NotImplementedError + + @handler("DeregisterStreamConsumer") + def deregister_stream_consumer( + self, + context: RequestContext, + stream_arn: StreamARN = None, + consumer_name: ConsumerName = None, + consumer_arn: ConsumerARN = None, + ) -> None: + raise NotImplementedError + + @handler("DescribeLimits") + def describe_limits( + self, + context: RequestContext, + ) -> DescribeLimitsOutput: + raise NotImplementedError + + @handler("DescribeStream") + def describe_stream( + self, + context: RequestContext, + stream_name: StreamName, + limit: DescribeStreamInputLimit = None, + exclusive_start_shard_id: ShardId = None, + ) -> DescribeStreamOutput: + raise NotImplementedError + + @handler("DescribeStreamConsumer") + def describe_stream_consumer( + self, + context: RequestContext, + stream_arn: StreamARN = None, + consumer_name: ConsumerName = None, + consumer_arn: ConsumerARN = None, + ) -> DescribeStreamConsumerOutput: + raise NotImplementedError + + @handler("DescribeStreamSummary") + def describe_stream_summary( + self, context: RequestContext, stream_name: StreamName + ) -> DescribeStreamSummaryOutput: + raise NotImplementedError + + @handler("DisableEnhancedMonitoring") + def disable_enhanced_monitoring( + self, context: RequestContext, stream_name: StreamName, shard_level_metrics: MetricsNameList + ) -> EnhancedMonitoringOutput: + raise NotImplementedError + + @handler("EnableEnhancedMonitoring") + def enable_enhanced_monitoring( + self, context: RequestContext, stream_name: StreamName, shard_level_metrics: MetricsNameList + ) -> EnhancedMonitoringOutput: + raise NotImplementedError + + @handler("GetRecords") + def get_records( + self, + context: RequestContext, + shard_iterator: ShardIterator, + limit: GetRecordsInputLimit = None, + ) -> GetRecordsOutput: + raise NotImplementedError + + @handler("GetShardIterator") + def get_shard_iterator( + self, + context: RequestContext, + stream_name: StreamName, + shard_id: ShardId, + shard_iterator_type: ShardIteratorType, + starting_sequence_number: SequenceNumber = None, + timestamp: Timestamp = None, + ) -> GetShardIteratorOutput: + raise NotImplementedError + + @handler("IncreaseStreamRetentionPeriod") + def increase_stream_retention_period( + self, + context: RequestContext, + stream_name: StreamName, + retention_period_hours: RetentionPeriodHours, + ) -> None: + raise NotImplementedError + + @handler("ListShards") + def list_shards( + self, + context: RequestContext, + stream_name: StreamName = None, + next_token: NextToken = None, + exclusive_start_shard_id: ShardId = None, + max_results: ListShardsInputLimit = None, + stream_creation_timestamp: Timestamp = None, + shard_filter: ShardFilter = None, + ) -> ListShardsOutput: + raise NotImplementedError + + @handler("ListStreamConsumers") + def list_stream_consumers( + self, + context: RequestContext, + stream_arn: StreamARN, + next_token: NextToken = None, + max_results: ListStreamConsumersInputLimit = None, + stream_creation_timestamp: Timestamp = None, + ) -> ListStreamConsumersOutput: + raise NotImplementedError + + @handler("ListStreams") + def list_streams( + self, + context: RequestContext, + limit: ListStreamsInputLimit = None, + exclusive_start_stream_name: StreamName = None, + ) -> ListStreamsOutput: + raise NotImplementedError + + @handler("ListTagsForStream") + def list_tags_for_stream( + self, + context: RequestContext, + stream_name: StreamName, + exclusive_start_tag_key: TagKey = None, + limit: ListTagsForStreamInputLimit = None, + ) -> ListTagsForStreamOutput: + raise NotImplementedError + + @handler("MergeShards") + def merge_shards( + self, + context: RequestContext, + stream_name: StreamName, + shard_to_merge: ShardId, + adjacent_shard_to_merge: ShardId, + ) -> None: + raise NotImplementedError + + @handler("PutRecord") + def put_record( + self, + context: RequestContext, + stream_name: StreamName, + data: Data, + partition_key: PartitionKey, + explicit_hash_key: HashKey = None, + sequence_number_for_ordering: SequenceNumber = None, + ) -> PutRecordOutput: + raise NotImplementedError + + @handler("PutRecords") + def put_records( + self, context: RequestContext, records: PutRecordsRequestEntryList, stream_name: StreamName + ) -> PutRecordsOutput: + raise NotImplementedError + + @handler("RegisterStreamConsumer") + def register_stream_consumer( + self, context: RequestContext, stream_arn: StreamARN, consumer_name: ConsumerName + ) -> RegisterStreamConsumerOutput: + raise NotImplementedError + + @handler("RemoveTagsFromStream") + def remove_tags_from_stream( + self, context: RequestContext, stream_name: StreamName, tag_keys: TagKeyList + ) -> None: + raise NotImplementedError + + @handler("SplitShard") + def split_shard( + self, + context: RequestContext, + stream_name: StreamName, + shard_to_split: ShardId, + new_starting_hash_key: HashKey, + ) -> None: + raise NotImplementedError + + @handler("StartStreamEncryption") + def start_stream_encryption( + self, + context: RequestContext, + stream_name: StreamName, + encryption_type: EncryptionType, + key_id: KeyId, + ) -> None: + raise NotImplementedError + + @handler("StopStreamEncryption") + def stop_stream_encryption( + self, + context: RequestContext, + stream_name: StreamName, + encryption_type: EncryptionType, + key_id: KeyId, + ) -> None: + raise NotImplementedError + + @handler("SubscribeToShard") + def subscribe_to_shard( + self, + context: RequestContext, + consumer_arn: ConsumerARN, + shard_id: ShardId, + starting_position: StartingPosition, + ) -> SubscribeToShardOutput: + raise NotImplementedError + + @handler("UpdateShardCount") + def update_shard_count( + self, + context: RequestContext, + stream_name: StreamName, + target_shard_count: PositiveIntegerObject, + scaling_type: ScalingType, + ) -> UpdateShardCountOutput: + raise NotImplementedError + + @handler("UpdateStreamMode") + def update_stream_mode( + self, context: RequestContext, stream_arn: StreamARN, stream_mode_details: StreamModeDetails + ) -> None: + raise NotImplementedError diff --git a/localstack/services/kinesis/kinesis_starter.py b/localstack/services/kinesis/kinesis_starter.py index e84f9c9d39dad..b1c8d5813b5a3 100644 --- a/localstack/services/kinesis/kinesis_starter.py +++ b/localstack/services/kinesis/kinesis_starter.py @@ -4,6 +4,7 @@ from localstack import config from localstack.services.infra import log_startup_message, start_proxy_for_service from localstack.services.kinesis import kinesalite_server, kinesis_mock_server +from localstack.services.plugins import SERVICE_PLUGINS from localstack.utils.aws import aws_stack from localstack.utils.serving import Server @@ -38,12 +39,16 @@ def start_kinesis( _server.start() log_startup_message("Kinesis") port = port or config.service_port("kinesis") - start_proxy_for_service( - "kinesis", - port, - backend_port=_server.port, - update_listener=update_listener, - ) + + # TODO: flip back to "!= kinesis:asf" to be sure we have the old control path when merging + if SERVICE_PLUGINS.get("kinesis").name() == "kinesis:legacy": + start_proxy_for_service( + "kinesis", + port, + backend_port=_server.port, + update_listener=update_listener, + ) + return _server diff --git a/localstack/services/kinesis/provider.py b/localstack/services/kinesis/provider.py new file mode 100644 index 0000000000000..54946de3ef2c2 --- /dev/null +++ b/localstack/services/kinesis/provider.py @@ -0,0 +1,374 @@ +import logging +import time +from collections import defaultdict +from datetime import datetime +from random import random +from typing import Dict, List, Set + +from localstack import config +from localstack.aws.api import RequestContext +from localstack.aws.api.kinesis import ( + BooleanObject, + Consumer, + ConsumerARN, + ConsumerDescription, + ConsumerName, + ConsumerStatus, + Data, + DescribeStreamConsumerOutput, + EnhancedMonitoringOutput, + HashKey, + KinesisApi, + ListStreamConsumersInputLimit, + ListStreamConsumersOutput, + MetricsName, + MetricsNameList, + NextToken, + PartitionKey, + PositiveIntegerObject, + ProvisionedThroughputExceededException, + PutRecordsOutput, + PutRecordsRequestEntryList, + PutRecordsResultEntry, + RegisterStreamConsumerOutput, + ResourceInUseException, + ResourceNotFoundException, + ScalingType, + SequenceNumber, + ShardId, + StartingPosition, + StreamARN, + StreamModeDetails, + StreamName, + SubscribeToShardEvent, + SubscribeToShardEventStream, + SubscribeToShardOutput, + Timestamp, + UpdateShardCountOutput, +) +from localstack.aws.forwarder import HttpFallbackDispatcher +from localstack.aws.proxy import AwsApiListener +from localstack.constants import LOCALHOST +from localstack.services.generic_proxy import RegionBackend +from localstack.services.kinesis.kinesis_starter import check_kinesis, start_kinesis +from localstack.services.plugins import ServiceLifecycleHook +from localstack.utils.analytics import event_publisher +from localstack.utils.aws import aws_stack + +LOG = logging.getLogger(__name__) + +# TODO ASF: Check if we need to implement CBOR encoding in the serializer! +# TODO ASF: Set "X-Amzn-Errortype" (HEADER_AMZN_ERROR_TYPE) on responses +# TODO ASF: Rewrite responses +# - Region in content of responses +# - Record rewriting: +# - SDKv2: Transform timestamps to int? +# - Remove double quotes for JSON responses +# - Convert base64 encoded data back to bytes for the cbor encoding + + +class KinesisApiListener(AwsApiListener): + def __init__(self, provider=None): + provider = provider or KinesisProvider() + self.provider = provider + super().__init__("kinesis", HttpFallbackDispatcher(provider, provider.get_forward_url)) + + +class KinesisBackend(RegionBackend): + def __init__(self): + # list of stream consumer details + self.stream_consumers: List[ConsumerDescription] = [] + # maps stream name to list of enhanced monitoring metrics + self.enhanced_metrics: Dict[StreamName, Set[MetricsName]] = defaultdict(set) + + +def find_stream_for_consumer(consumer_arn): + kinesis = aws_stack.connect_to_service("kinesis") + for stream_name in kinesis.list_streams()["StreamNames"]: + stream_arn = aws_stack.kinesis_stream_arn(stream_name) + for cons in kinesis.list_stream_consumers(StreamARN=stream_arn)["Consumers"]: + if cons["ConsumerARN"] == consumer_arn: + return stream_name + raise Exception("Unable to find stream for stream consumer %s" % consumer_arn) + + +def find_consumer(consumer_arn="", consumer_name="", stream_arn=""): + stream_consumers = KinesisBackend.get().stream_consumers + for consumer in stream_consumers: + if consumer_arn and consumer_arn == consumer.get("ConsumerARN"): + return consumer + elif consumer_name == consumer.get("ConsumerName") and stream_arn == consumer.get( + "StreamARN" + ): + return consumer + + +class KinesisProvider(KinesisApi, ServiceLifecycleHook): + def __init__(self): + self._server = None + + def on_before_start(self): + self._server = start_kinesis() + check_kinesis() + + def get_forward_url(self): + """Return the URL of the backend Kinesis server to forward requests to""" + return f"http://{LOCALHOST}:{self._server.port}" + + def subscribe_to_shard( + self, + context: RequestContext, + consumer_arn: ConsumerARN, + shard_id: ShardId, + starting_position: StartingPosition, + ) -> SubscribeToShardOutput: + kinesis = aws_stack.connect_to_service("kinesis") + stream_name = find_stream_for_consumer(consumer_arn) + iter_type = starting_position["Type"] + kwargs = {} + starting_sequence_number = starting_position.get("SequenceNumber") or "0" + if iter_type in ["AT_SEQUENCE_NUMBER", "AFTER_SEQUENCE_NUMBER"]: + kwargs["StartingSequenceNumber"] = starting_sequence_number + elif iter_type in ["AT_TIMESTAMP"]: + # or value is just an example timestamp from aws docs + timestamp = starting_position.get("Timestamp") or 1459799926.480 + kwargs["Timestamp"] = timestamp + initial_shard_iterator = kinesis.get_shard_iterator( + StreamName=stream_name, ShardId=shard_id, ShardIteratorType=iter_type, **kwargs + )["ShardIterator"] + + def event_generator(): + shard_iterator = initial_shard_iterator + last_sequence_number = starting_sequence_number + # TODO: find better way to run loop up to max 5 minutes (until connection terminates)! + for i in range(5 * 60): + try: + result = kinesis.get_records(ShardIterator=shard_iterator) + except Exception as e: + if "ResourceNotFoundException" in str(e): + LOG.debug( + 'Kinesis stream "%s" has been deleted, closing shard subscriber', + stream_name, + ) + return + raise + shard_iterator = result.get("NextShardIterator") + records = result.get("Records", []) + if not records: + time.sleep(1) + continue + + yield SubscribeToShardEventStream( + SubscribeToShardEvent=SubscribeToShardEvent( + Records=records, + ContinuationSequenceNumber=str(last_sequence_number), + MillisBehindLatest=0, + ChildShards=[], + ) + ) + + return SubscribeToShardOutput(EventStream=event_generator()) + + def put_record( + self, + context: RequestContext, + stream_name: StreamName, + data: Data, + partition_key: PartitionKey, + explicit_hash_key: HashKey = None, + sequence_number_for_ordering: SequenceNumber = None, + ): + if random() < config.KINESIS_ERROR_PROBABILITY: + raise ProvisionedThroughputExceededException( + "Rate exceeded for shard X in stream Y under account Z." + ) + # If "we were lucky" and the error probability didn't hit, we raise a NotImplementedError in order to + # trigger the fallback to kinesis-mock or kinesalite + raise NotImplementedError + + def put_records( + self, context: RequestContext, records: PutRecordsRequestEntryList, stream_name: StreamName + ) -> PutRecordsOutput: + if random() < config.KINESIS_ERROR_PROBABILITY: + records_count = len(records) if records is not None else 0 + records = [ + PutRecordsResultEntry( + ErrorCode="ProvisionedThroughputExceededException", + ErrorMessage="Rate exceeded for shard X in stream Y under account Z.", + ) + ] * records_count + return PutRecordsOutput(FailedRecordCount=1, Records=records) + # If "we were lucky" and the error probability didn't hit, we raise a NotImplementedError in order to + # trigger the fallback to kinesis-mock or kinesalite + raise NotImplementedError + + def register_stream_consumer( + self, context: RequestContext, stream_arn: StreamARN, consumer_name: ConsumerName + ) -> RegisterStreamConsumerOutput: + if config.KINESIS_PROVIDER == "kinesalite": + prev_consumer = find_consumer(stream_arn=stream_arn, consumer_name=consumer_name) + if prev_consumer: + raise ResourceInUseException( + f"Consumer {prev_consumer['ConsumerARN']} already exists" + ) + consumer = Consumer( + ConsumerName=consumer_name, + ConsumerStatus=ConsumerStatus.ACTIVE, + ConsumerARN=f"{stream_arn}/consumer/{consumer_name}", + ConsumerCreationTimestamp=datetime.now(), + ) + consumer_description = ConsumerDescription(**consumer) + consumer_description["StreamARN"] = stream_arn + KinesisBackend.get().stream_consumers.append(consumer_description) + return RegisterStreamConsumerOutput(Consumer=consumer) + + # If kinesis-mock is used, we forward the request through the fallback by raising a NotImplementedError + raise NotImplementedError + + def deregister_stream_consumer( + self, + context: RequestContext, + stream_arn: StreamARN = "", + consumer_name: ConsumerName = "", + consumer_arn: ConsumerARN = "", + ) -> None: + if config.KINESIS_PROVIDER == "kinesalite": + + def consumer_filter(consumer: ConsumerDescription): + return not ( + consumer.get("ConsumerARN") == consumer_arn + or ( + consumer.get("StreamARN") == stream_arn + and consumer.get("ConsumerName") == consumer_name + ) + ) + + region = KinesisBackend.get() + region.stream_consumers = list(filter(consumer_filter, region.stream_consumers)) + return None + + # If kinesis-mock is used, we forward the request through the fallback by raising a NotImplementedError + raise NotImplementedError + + def list_stream_consumers( + self, + context: RequestContext, + stream_arn: StreamARN, + next_token: NextToken = None, + max_results: ListStreamConsumersInputLimit = None, + stream_creation_timestamp: Timestamp = None, + ) -> ListStreamConsumersOutput: + if config.KINESIS_PROVIDER == "kinesalite": + stream_consumers = KinesisBackend.get().stream_consumers + consumers: List[Consumer] = [] + for consumer_description in stream_consumers: + consumer = Consumer( + ConsumerARN=consumer_description["ConsumerARN"], + ConsumerCreationTimestamp=consumer_description["ConsumerCreationTimestamp"], + ConsumerName=consumer_description["ConsumerName"], + ConsumerStatus=consumer_description["ConsumerStatus"], + ) + consumers.append(consumer) + return ListStreamConsumersOutput(Consumers=consumers) + + # If kinesis-mock is used, we forward the request through the fallback by raising a NotImplementedError + raise NotImplementedError + + def describe_stream_consumer( + self, + context: RequestContext, + stream_arn: StreamARN = None, + consumer_name: ConsumerName = None, + consumer_arn: ConsumerARN = None, + ) -> DescribeStreamConsumerOutput: + if config.KINESIS_PROVIDER == "kinesalite": + consumer_to_locate = find_consumer(consumer_arn, consumer_name, stream_arn) + if not consumer_to_locate: + raise ResourceNotFoundException( + f"Consumer {consumer_arn or consumer_name} not found." + ) + return DescribeStreamConsumerOutput(ConsumerDescription=consumer_to_locate) + + # If kinesis-mock is used, we forward the request through the fallback by raising a NotImplementedError + raise NotImplementedError + + def enable_enhanced_monitoring( + self, context: RequestContext, stream_name: StreamName, shard_level_metrics: MetricsNameList + ) -> EnhancedMonitoringOutput: + if config.KINESIS_PROVIDER == "kinesalite": + stream_metrics = KinesisBackend.get().enhanced_metrics[stream_name] + stream_metrics.update(shard_level_metrics) + stream_metrics_list = list(stream_metrics) + return EnhancedMonitoringOutput( + StreamName=stream_name, + CurrentShardLevelMetrics=stream_metrics_list, + DesiredShardLevelMetrics=stream_metrics_list, + ) + + # If kinesis-mock is used, we forward the request through the fallback by raising a NotImplementedError + raise NotImplementedError + + def disable_enhanced_monitoring( + self, context: RequestContext, stream_name: StreamName, shard_level_metrics: MetricsNameList + ) -> EnhancedMonitoringOutput: + if config.KINESIS_PROVIDER == "kinesalite": + region = KinesisBackend.get() + region.enhanced_metrics[stream_name] = region.enhanced_metrics[stream_name] - set( + shard_level_metrics + ) + stream_metrics_list = list(region.enhanced_metrics[stream_name]) + return EnhancedMonitoringOutput( + StreamName=stream_name, + CurrentShardLevelMetrics=stream_metrics_list, + DesiredShardLevelMetrics=stream_metrics_list, + ) + + # If kinesis-mock is used, we forward the request through the fallback by raising a NotImplementedError + raise NotImplementedError + + def update_shard_count( + self, + context: RequestContext, + stream_name: StreamName, + target_shard_count: PositiveIntegerObject, + scaling_type: ScalingType, + ) -> UpdateShardCountOutput: + if config.KINESIS_PROVIDER == "kinesalite": + # Currently, kinesalite - which backs the Kinesis implementation for localstack - does + # not support UpdateShardCount: https://github.com/mhart/kinesalite/issues/61 + # Terraform makes the call to UpdateShardCount when it + # applies Kinesis resources. A Terraform run fails when this is not present. + # This code just returns a successful response, bypassing the 400 response that kinesalite would return. + return UpdateShardCountOutput( + CurrentShardCount=1, + StreamName=stream_name, + TargetShardCount=target_shard_count, + ) + + # If kinesis-mock is used, we forward the request through the fallback by raising a NotImplementedError + raise NotImplementedError + + def create_stream( + self, + context: RequestContext, + stream_name: StreamName, + shard_count: PositiveIntegerObject = None, + stream_mode_details: StreamModeDetails = None, + ) -> None: + payload = {"n": event_publisher.get_hash(stream_name), "s": shard_count} + event_publisher.fire_event(event_publisher.EVENT_KINESIS_CREATE_STREAM, payload=payload) + + # After the event is logged, the request is forwarded to the fallback by raising a NotImplementedError + raise NotImplementedError + + def delete_stream( + self, + context: RequestContext, + stream_name: StreamName, + enforce_consumer_deletion: BooleanObject = None, + ) -> None: + payload = {"n": event_publisher.get_hash(stream_name)} + event_publisher.fire_event(event_publisher.EVENT_KINESIS_DELETE_STREAM, payload=payload) + + # After the event is logged, the request is forwarded to the fallback by raising a NotImplementedError + raise NotImplementedError diff --git a/localstack/services/providers.py b/localstack/services/providers.py index d1757019e67a6..768b0064d78c1 100644 --- a/localstack/services/providers.py +++ b/localstack/services/providers.py @@ -1,5 +1,6 @@ from localstack import config from localstack.aws.proxy import AwsApiListener +from localstack.services.kinesis.provider import KinesisApiListener from localstack.services.moto import MotoFallbackDispatcher from localstack.services.plugins import Service, aws_provider @@ -147,8 +148,8 @@ def sts(): return Service("sts", listener=listener) -@aws_provider() -def kinesis(): +@aws_provider(api="kinesis", name="legacy") +def kinesis_legacy(): from localstack.services.kinesis import kinesis_listener, kinesis_starter return Service( @@ -159,6 +160,17 @@ def kinesis(): ) +@aws_provider(api="kinesis", name="default") +def kinesis_asf(): + # from localstack.services.kinesis import kinesis_listener, kinesis_starter + # return Service("kinesis", listener=kinesis_listener.UPDATE_KINESIS, start=kinesis_starter.start_kinesis, check=kinesis_starter.check_kinesis,) + listener = KinesisApiListener() + return Service( + "kinesis", + listener=listener, + lifecycle_hook=listener.provider, + ) + @aws_provider() def kms(): if config.KMS_PROVIDER == "local-kms": From 827ecf33a2d0bed79162ee21c1a5656d45c6a76c Mon Sep 17 00:00:00 2001 From: Thomas Rausch Date: Wed, 27 Jul 2022 14:24:36 +0200 Subject: [PATCH 02/16] fix linting errors --- localstack/services/providers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/localstack/services/providers.py b/localstack/services/providers.py index 768b0064d78c1..e5c1fb95de5b6 100644 --- a/localstack/services/providers.py +++ b/localstack/services/providers.py @@ -171,6 +171,7 @@ def kinesis_asf(): lifecycle_hook=listener.provider, ) + @aws_provider() def kms(): if config.KMS_PROVIDER == "local-kms": From 3013a517a35a0cd26b058136ccd190e75b570f07 Mon Sep 17 00:00:00 2001 From: Alexander Rashed Date: Tue, 16 Aug 2022 16:33:36 +0200 Subject: [PATCH 03/16] update ASF API for Kinesis --- localstack/aws/api/kinesis/__init__.py | 60 +++++++++++++++++++------- 1 file changed, 45 insertions(+), 15 deletions(-) diff --git a/localstack/aws/api/kinesis/__init__.py b/localstack/aws/api/kinesis/__init__.py index 75ad27d107783..e5b446f548436 100644 --- a/localstack/aws/api/kinesis/__init__.py +++ b/localstack/aws/api/kinesis/__init__.py @@ -95,63 +95,93 @@ class StreamStatus(str): class ExpiredIteratorException(ServiceException): - message: Optional[ErrorMessage] + code: str = "ExpiredIteratorException" + sender_fault: bool = False + status_code: int = 400 class ExpiredNextTokenException(ServiceException): - message: Optional[ErrorMessage] + code: str = "ExpiredNextTokenException" + sender_fault: bool = False + status_code: int = 400 class InternalFailureException(ServiceException): - message: Optional[ErrorMessage] + code: str = "InternalFailureException" + sender_fault: bool = False + status_code: int = 400 class InvalidArgumentException(ServiceException): - message: Optional[ErrorMessage] + code: str = "InvalidArgumentException" + sender_fault: bool = False + status_code: int = 400 class KMSAccessDeniedException(ServiceException): - message: Optional[ErrorMessage] + code: str = "KMSAccessDeniedException" + sender_fault: bool = False + status_code: int = 400 class KMSDisabledException(ServiceException): - message: Optional[ErrorMessage] + code: str = "KMSDisabledException" + sender_fault: bool = False + status_code: int = 400 class KMSInvalidStateException(ServiceException): - message: Optional[ErrorMessage] + code: str = "KMSInvalidStateException" + sender_fault: bool = False + status_code: int = 400 class KMSNotFoundException(ServiceException): - message: Optional[ErrorMessage] + code: str = "KMSNotFoundException" + sender_fault: bool = False + status_code: int = 400 class KMSOptInRequired(ServiceException): - message: Optional[ErrorMessage] + code: str = "KMSOptInRequired" + sender_fault: bool = False + status_code: int = 400 class KMSThrottlingException(ServiceException): - message: Optional[ErrorMessage] + code: str = "KMSThrottlingException" + sender_fault: bool = False + status_code: int = 400 class LimitExceededException(ServiceException): - message: Optional[ErrorMessage] + code: str = "LimitExceededException" + sender_fault: bool = False + status_code: int = 400 class ProvisionedThroughputExceededException(ServiceException): - message: Optional[ErrorMessage] + code: str = "ProvisionedThroughputExceededException" + sender_fault: bool = False + status_code: int = 400 class ResourceInUseException(ServiceException): - message: Optional[ErrorMessage] + code: str = "ResourceInUseException" + sender_fault: bool = False + status_code: int = 400 class ResourceNotFoundException(ServiceException): - message: Optional[ErrorMessage] + code: str = "ResourceNotFoundException" + sender_fault: bool = False + status_code: int = 400 class ValidationException(ServiceException): - message: Optional[ErrorMessage] + code: str = "ValidationException" + sender_fault: bool = False + status_code: int = 400 TagMap = Dict[TagKey, TagValue] From a91313e88564812f5daed2083a8ef7bfc666398e Mon Sep 17 00:00:00 2001 From: Alexander Rashed Date: Tue, 16 Aug 2022 16:36:09 +0200 Subject: [PATCH 04/16] remove old event logging code --- localstack/services/kinesis/provider.py | 28 ------------------------- 1 file changed, 28 deletions(-) diff --git a/localstack/services/kinesis/provider.py b/localstack/services/kinesis/provider.py index 54946de3ef2c2..e070ece65fc5f 100644 --- a/localstack/services/kinesis/provider.py +++ b/localstack/services/kinesis/provider.py @@ -8,7 +8,6 @@ from localstack import config from localstack.aws.api import RequestContext from localstack.aws.api.kinesis import ( - BooleanObject, Consumer, ConsumerARN, ConsumerDescription, @@ -38,7 +37,6 @@ ShardId, StartingPosition, StreamARN, - StreamModeDetails, StreamName, SubscribeToShardEvent, SubscribeToShardEventStream, @@ -52,7 +50,6 @@ from localstack.services.generic_proxy import RegionBackend from localstack.services.kinesis.kinesis_starter import check_kinesis, start_kinesis from localstack.services.plugins import ServiceLifecycleHook -from localstack.utils.analytics import event_publisher from localstack.utils.aws import aws_stack LOG = logging.getLogger(__name__) @@ -347,28 +344,3 @@ def update_shard_count( # If kinesis-mock is used, we forward the request through the fallback by raising a NotImplementedError raise NotImplementedError - - def create_stream( - self, - context: RequestContext, - stream_name: StreamName, - shard_count: PositiveIntegerObject = None, - stream_mode_details: StreamModeDetails = None, - ) -> None: - payload = {"n": event_publisher.get_hash(stream_name), "s": shard_count} - event_publisher.fire_event(event_publisher.EVENT_KINESIS_CREATE_STREAM, payload=payload) - - # After the event is logged, the request is forwarded to the fallback by raising a NotImplementedError - raise NotImplementedError - - def delete_stream( - self, - context: RequestContext, - stream_name: StreamName, - enforce_consumer_deletion: BooleanObject = None, - ) -> None: - payload = {"n": event_publisher.get_hash(stream_name)} - event_publisher.fire_event(event_publisher.EVENT_KINESIS_DELETE_STREAM, payload=payload) - - # After the event is logged, the request is forwarded to the fallback by raising a NotImplementedError - raise NotImplementedError From 3538237a25d6cdf39beaa694e47b49677907088a Mon Sep 17 00:00:00 2001 From: Alexander Rashed Date: Thu, 18 Aug 2022 13:34:11 +0200 Subject: [PATCH 05/16] avoid extending AwsApiListener --- localstack/services/kinesis/provider.py | 9 --------- localstack/services/providers.py | 15 ++++++++------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/localstack/services/kinesis/provider.py b/localstack/services/kinesis/provider.py index e070ece65fc5f..0e2b684c4ee94 100644 --- a/localstack/services/kinesis/provider.py +++ b/localstack/services/kinesis/provider.py @@ -44,8 +44,6 @@ Timestamp, UpdateShardCountOutput, ) -from localstack.aws.forwarder import HttpFallbackDispatcher -from localstack.aws.proxy import AwsApiListener from localstack.constants import LOCALHOST from localstack.services.generic_proxy import RegionBackend from localstack.services.kinesis.kinesis_starter import check_kinesis, start_kinesis @@ -64,13 +62,6 @@ # - Convert base64 encoded data back to bytes for the cbor encoding -class KinesisApiListener(AwsApiListener): - def __init__(self, provider=None): - provider = provider or KinesisProvider() - self.provider = provider - super().__init__("kinesis", HttpFallbackDispatcher(provider, provider.get_forward_url)) - - class KinesisBackend(RegionBackend): def __init__(self): # list of stream consumer details diff --git a/localstack/services/providers.py b/localstack/services/providers.py index e5c1fb95de5b6..290f1d1493d37 100644 --- a/localstack/services/providers.py +++ b/localstack/services/providers.py @@ -1,6 +1,6 @@ from localstack import config +from localstack.aws.forwarder import HttpFallbackDispatcher from localstack.aws.proxy import AwsApiListener -from localstack.services.kinesis.provider import KinesisApiListener from localstack.services.moto import MotoFallbackDispatcher from localstack.services.plugins import Service, aws_provider @@ -160,15 +160,16 @@ def kinesis_legacy(): ) -@aws_provider(api="kinesis", name="default") -def kinesis_asf(): - # from localstack.services.kinesis import kinesis_listener, kinesis_starter - # return Service("kinesis", listener=kinesis_listener.UPDATE_KINESIS, start=kinesis_starter.start_kinesis, check=kinesis_starter.check_kinesis,) - listener = KinesisApiListener() +@aws_provider() +def kinesis(): + from localstack.services.kinesis.provider import KinesisProvider + + provider = KinesisProvider() + listener = AwsApiListener("kinesis", HttpFallbackDispatcher(provider, provider.get_forward_url)) return Service( "kinesis", listener=listener, - lifecycle_hook=listener.provider, + lifecycle_hook=provider, ) From 187b3f83605597fc229b93bf47d9c05f6b24701b Mon Sep 17 00:00:00 2001 From: Alexander Rashed Date: Fri, 19 Aug 2022 11:07:27 +0200 Subject: [PATCH 06/16] add content-negotiation support to ASF serializer --- localstack/aws/forwarder.py | 7 +- localstack/aws/handlers/service.py | 4 +- localstack/aws/protocol/serializer.py | 343 +++++++++++++++------ localstack/aws/skeleton.py | 8 +- localstack/constants.py | 1 + localstack/services/kinesis/provider.py | 9 - localstack/services/moto.py | 53 ++-- localstack/services/providers.py | 6 +- localstack/services/sqs/query_api.py | 5 +- localstack/services/sts/provider.py | 40 +-- tests/integration/apigateway_fixtures.py | 6 +- tests/integration/test_apigateway.py | 2 +- tests/integration/test_apigateway_api.py | 6 + tests/integration/test_moto.py | 45 ++- tests/integration/test_s3control.py | 2 +- tests/unit/aws/handlers/service.py | 8 +- tests/unit/aws/protocol/test_serializer.py | 34 +- 17 files changed, 346 insertions(+), 233 deletions(-) diff --git a/localstack/aws/forwarder.py b/localstack/aws/forwarder.py index 2d61bcf664020..5e6b03452efe0 100644 --- a/localstack/aws/forwarder.py +++ b/localstack/aws/forwarder.py @@ -74,7 +74,12 @@ def _forward_request(context, service_request: ServiceRequest = None) -> Service parameters=service_request, region=context.region, ) - local_context.request.headers.update(context.request.headers) + # update the newly created context with non-payload specific request headers (the payload can differ from + # the original request, f.e. it could be JSON encoded now while the initial request was CBOR encoded) + headers = Headers(context.request.headers) + headers.pop("Content-Type", None) + headers.pop("Content-Length", None) + local_context.request.headers.update(headers) context = local_context return forward_request(context, forward_url_getter) diff --git a/localstack/aws/handlers/service.py b/localstack/aws/handlers/service.py index fde24f0785445..cdebe4b6a4bf0 100644 --- a/localstack/aws/handlers/service.py +++ b/localstack/aws/handlers/service.py @@ -150,7 +150,7 @@ def create_not_implemented_response(self, context): message = f"no handler for operation '{operation_name}' on service '{service_name}'" error = CommonServiceException("InternalFailure", message, status_code=501) serializer = create_serializer(context.service) - return serializer.serialize_error_to_response(error, operation) + return serializer.serialize_error_to_response(error, operation, context.request.headers) class ServiceExceptionSerializer(ExceptionHandler): @@ -225,7 +225,7 @@ def create_exception_response(self, exception: Exception, context: RequestContex context.service_exception = error serializer = create_serializer(context.service) # TODO: serializer cache - return serializer.serialize_error_to_response(error, operation) + return serializer.serialize_error_to_response(error, operation, context.request.headers) class ServiceResponseParser(Handler): diff --git a/localstack/aws/protocol/serializer.py b/localstack/aws/protocol/serializer.py index 26061ef7584e3..33974576e8b9f 100644 --- a/localstack/aws/protocol/serializer.py +++ b/localstack/aws/protocol/serializer.py @@ -80,17 +80,31 @@ from datetime import datetime from email.utils import formatdate from struct import pack -from typing import Any, Iterable, Iterator, Optional, Tuple, Union +from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union from xml.etree import ElementTree as ETree +import cbor2 +import xmltodict from boto.utils import ISO8601 from botocore.model import ListShape, MapShape, OperationModel, ServiceModel, Shape, StructureShape from botocore.serialize import ISO8601_MICRO from botocore.utils import calculate_md5, is_json_value_header, parse_to_aware_datetime from moto.core.utils import gen_amzn_requestid_long +from werkzeug.datastructures import Headers, MIMEAccept +from werkzeug.http import parse_accept_header from localstack.aws.api import HttpResponse, ServiceException +from localstack.constants import ( + APPLICATION_AMZ_CBOR_1_1, + APPLICATION_AMZ_JSON_1_0, + APPLICATION_AMZ_JSON_1_1, + APPLICATION_CBOR, + APPLICATION_JSON, + APPLICATION_XML, + TEXT_XML, +) from localstack.utils.common import to_bytes, to_str +from localstack.utils.xml import strip_xmlns LOG = logging.getLogger(__name__) @@ -106,7 +120,7 @@ class ResponseSerializerError(Exception): class UnknownSerializerError(ResponseSerializerError): """ - Error which indicates that the raised exception of the serializer could be caused by invalid data or by any other + Error which indicates that the exception raised by the serializer could be caused by invalid data or by any other (unknown) issue. Errors like this should be reported and indicate an issue in the serializer itself. """ @@ -115,8 +129,8 @@ class UnknownSerializerError(ResponseSerializerError): class ProtocolSerializerError(ResponseSerializerError): """ - Error which indicates that the given data is not compliant with the service's specification and cannot be serialized. - This usually results in a response to the client with an HTTP 5xx status code (internal server error). + Error which indicates that the given data is not compliant with the service's specification and cannot be + serialized. This usually results in a response to the client with an HTTP 5xx status code (internal server error). """ pass @@ -156,10 +170,13 @@ class ResponseSerializer(abc.ABC): TIMESTAMP_FORMAT = "iso8601" # Event streaming binary data type mapping for type "string" AWS_BINARY_DATA_TYPE_STRING = 7 + # Defines the supported mime types of the specific serializer. Sorted by priority (preferred / default first). + # Needs to be specified by subclasses. + SUPPORTED_MIME_TYPES: List[str] = [] @_handle_exceptions def serialize_to_response( - self, response: dict, operation_model: OperationModel + self, response: dict, operation_model: OperationModel, headers: Optional[Dict | Headers] ) -> HttpResponse: """ Takes a response dict and serializes it to an actual HttpResponse. @@ -167,19 +184,25 @@ def serialize_to_response( :param response: to serialize :param operation_model: specification of the service & operation containing information about the shape of the service's output / response + :param headers: the headers of the incoming request this response should be serialized for. This is necessary + for features like Content-Negotiation (define response content type based on request headers). :return: HttpResponse which can be sent to the calling client :raises: ResponseSerializerError (either a ProtocolSerializerError or an UnknownSerializerError) """ + + # determine the preferred mime type (based on the serializer's supported mime types and the Accept header) + mime_type = self._get_mime_type(headers) + # if the operation has a streaming output, handle the serialization differently if operation_model.has_event_stream_output: - return self._serialize_event_stream(response, operation_model) + return self._serialize_event_stream(response, operation_model, mime_type) - serialized_response = self._create_default_response(operation_model) + serialized_response = self._create_default_response(operation_model, mime_type) shape = operation_model.output_shape # The shape can also be none (for empty responses), but it still needs to be serialized (to add some metadata) shape_members = shape.members if shape is not None else None self._serialize_response( - response, serialized_response, shape, shape_members, operation_model + response, serialized_response, shape, shape_members, operation_model, mime_type ) serialized_response = self._prepare_additional_traits_in_response( serialized_response, operation_model @@ -188,20 +211,28 @@ def serialize_to_response( @_handle_exceptions def serialize_error_to_response( - self, error: ServiceException, operation_model: OperationModel + self, + error: ServiceException, + operation_model: OperationModel, + headers: Optional[Dict | Headers], ) -> HttpResponse: """ Takes an error instance and serializes it to an actual HttpResponse. - Therefore this method is used for errors which should be serialized and transmitted to the calling client. + Therefore, this method is used for errors which should be serialized and transmitted to the calling client. :param error: to serialize :param operation_model: specification of the service & operation containing information about the shape of the service's output / response + :param headers: the headers of the incoming request this response should be serialized for. This is necessary + for features like Content-Negotiation (define response content type based on request headers). :return: HttpResponse which can be sent to the calling client :raises: ResponseSerializerError (either a ProtocolSerializerError or an UnknownSerializerError) """ + # determine the preferred mime type (based on the serializer's supported mime types and the Accept header) + mime_type = self._get_mime_type(headers) + # TODO implement streaming error serialization - serialized_response = self._create_default_response(operation_model) + serialized_response = self._create_default_response(operation_model, mime_type) if not error or not isinstance(error, ServiceException): raise ProtocolSerializerError( f"Error to serialize ({error.__class__.__name__ if error else None}) is not a ServiceException." @@ -209,7 +240,7 @@ def serialize_error_to_response( shape = operation_model.service_model.shape_for_error_code(error.code) serialized_response.status_code = error.status_code - self._serialize_error(error, serialized_response, shape, operation_model) + self._serialize_error(error, serialized_response, shape, operation_model, mime_type) serialized_response = self._prepare_additional_traits_in_response( serialized_response, operation_model ) @@ -222,11 +253,16 @@ def _serialize_response( shape: Optional[Shape], shape_members: dict, operation_model: OperationModel, + mime_type: str, ) -> None: raise NotImplementedError def _serialize_body_params( - self, params: dict, shape: Shape, operation_model: OperationModel + self, + params: dict, + shape: Shape, + operation_model: OperationModel, + mime_type: str, ) -> Optional[str]: """ Actually serializes the given params for the given shape to a string for the transmission in the body of the @@ -234,6 +270,7 @@ def _serialize_body_params( :param params: to serialize :param shape: to know how to serialize the params :param operation_model: for additional metadata + :param mime_type: Mime type which should be used to encode the payload :return: string containing the serialized body """ raise NotImplementedError @@ -244,11 +281,12 @@ def _serialize_error( response: HttpResponse, shape: StructureShape, operation_model: OperationModel, + mime_type: str, ) -> None: raise NotImplementedError def _serialize_event_stream( - self, response: dict, operation_model: OperationModel + self, response: dict, operation_model: OperationModel, mime_type: str ) -> HttpResponse: """ Serializes a given response dict (the return payload of a service implementation) to an _event stream_ using the @@ -256,6 +294,7 @@ def _serialize_event_stream( :param response: dictionary containing the payload for the response :param operation_model: describing the operation the response dict is being returned by + :param mime_type: Mime type which should be used to encode the payload :return: HttpResponse which can directly be sent to the client (in chunks) """ event_stream_shape = operation_model.get_event_stream_output() @@ -267,7 +306,7 @@ def event_stream_serializer() -> Iterable[bytes]: # yield convert_to_binary_event_payload("", event_type="initial-response") # create a default response - serialized_event_response = self._create_default_response(operation_model) + serialized_event_response = self._create_default_response(operation_model, mime_type) # get the members of the event stream shape event_stream_shape_members = ( event_stream_shape.members if event_stream_shape is not None else None @@ -297,6 +336,7 @@ def event_stream_serializer() -> Iterable[bytes]: event_member, event_member.members if event_member is not None else None, operation_model, + mime_type, ) # execute additional response traits (might be modifying the response) serialized_event_response = self._prepare_additional_traits_in_response( @@ -383,15 +423,49 @@ def _encode_event_payload( return result - def _create_default_response(self, operation_model: OperationModel) -> HttpResponse: + def _create_default_response( + self, operation_model: OperationModel, mime_type: str + ) -> HttpResponse: """ Creates a boilerplate default response to be used by subclasses as starting points. - Uses the default HTTP response status code defined in the operation model (if defined). + Uses the default HTTP response status code defined in the operation model (if defined), otherwise 200. :param operation_model: to extract the default HTTP status code + :param mime_type: Mime type which should be used to encode the payload :return: boilerplate HTTP response """ - return HttpResponse(response=b"", status=operation_model.http.get("responseCode", 200)) + return HttpResponse(status=operation_model.http.get("responseCode", 200)) + + def _get_mime_type(self, headers: Optional[Dict | Headers]) -> str: + """ + Extracts the accepted mime type from the request headers and returns a matching, supported mime type for the + serializer. + :param headers: to extract the "Accept" header from + :return: preferred mime type to be used by the serializer (if it is not accepted by the client, + an error is logged) + """ + accept_header = None + if headers and "Accept" in headers and not headers.get("Accept") == "*/*": + accept_header = headers.get("Accept") + elif headers and headers.get("Content-Type"): + # If there is no specific Accept header given, we use the given Content-Type as a fallback. + # i.e. if the request content was JSON encoded and the client doesn't send a specific an Accept header, the + # serializer should prefer JSON encoding. + content_type = headers.get("Content-Type") + LOG.debug( + "No accept header given. Using request's Content-Type (%s) as preferred response Content-Type.", + content_type, + ) + accept_header = content_type + ", */*" + mime_accept: MIMEAccept = parse_accept_header(accept_header, MIMEAccept) + mime_type = mime_accept.best_match(self.SUPPORTED_MIME_TYPES) + if not mime_type: + # There is no match between the supported mime types and the requested one(s) + LOG.debug( + "Determined accept type (%s) is not supported by this serializer.", accept_header + ) + mime_type = self.SUPPORTED_MIME_TYPES[0] + return mime_type # Some extra utility methods subclasses can use. @@ -427,7 +501,7 @@ def _convert_timestamp_to_str( def _get_serialized_name(shape: Shape, default_name: str) -> str: """ Returns the serialized name for the shape if it exists. - Otherwise it will return the passed in default_name. + Otherwise, it will return the passed in default_name. """ return shape.serialization.get("name", default_name) @@ -482,12 +556,15 @@ class BaseXMLResponseSerializer(ResponseSerializer): service to the client). """ + SUPPORTED_MIME_TYPES = [TEXT_XML, APPLICATION_XML, APPLICATION_JSON] + def _serialize_error( self, error: ServiceException, response: HttpResponse, shape: StructureShape, operation_model: OperationModel, + mime_type: str, ) -> None: # Check if we need to add a namespace attr = ( @@ -498,26 +575,28 @@ def _serialize_error( root = ETree.Element("ErrorResponse", attr) error_tag = ETree.SubElement(root, "Error") - self._add_error_tags(error, error_tag) + self._add_error_tags(error, error_tag, mime_type) request_id = ETree.SubElement(root, "RequestId") request_id.text = gen_amzn_requestid_long() - self._add_additional_error_tags(error, root, shape) + self._add_additional_error_tags(error, root, shape, mime_type) - response.set_response(self._encode_payload(self._xml_to_string(root))) + response.set_response(self._encode_payload(self._node_to_string(root, mime_type))) - def _add_error_tags(self, error: ServiceException, error_tag: ETree.Element) -> None: + def _add_error_tags( + self, error: ServiceException, error_tag: ETree.Element, mime_type: str + ) -> None: code_tag = ETree.SubElement(error_tag, "Code") code_tag.text = error.code message = self._get_error_message(error) if message: - self._default_serialize(error_tag, message, None, "Message") + self._default_serialize(error_tag, message, None, "Message", mime_type) if error.sender_fault: # The sender fault is either not set or "Sender" - self._default_serialize(error_tag, "Sender", None, "Type") + self._default_serialize(error_tag, "Sender", None, "Type", mime_type) def _add_additional_error_tags( - self, error: ServiceException, node: ETree, shape: StructureShape + self, error: ServiceException, node: ETree, shape: StructureShape, mime_type: str ): if shape: params = {} @@ -533,32 +612,38 @@ def _add_additional_error_tags( # Serialize the remaining params root_name = shape.serialization.get("name", shape.name) pseudo_root = ETree.Element("") - self._serialize(shape, params, pseudo_root, root_name) + self._serialize(shape, params, pseudo_root, root_name, mime_type) real_root = list(pseudo_root)[0] # Add the child elements to the already created root error element for child in list(real_root): node.append(child) def _serialize_body_params( - self, params: dict, shape: Shape, operation_model: OperationModel + self, + params: dict, + shape: Shape, + operation_model: OperationModel, + mime_type: str, ) -> Optional[str]: - root = self._serialize_body_params_to_xml(params, shape, operation_model) + root = self._serialize_body_params_to_xml(params, shape, operation_model, mime_type) self._prepare_additional_traits_in_xml(root) - return self._xml_to_string(root) + return self._node_to_string(root, mime_type) def _serialize_body_params_to_xml( - self, params: dict, shape: Shape, operation_model: OperationModel + self, params: dict, shape: Shape, operation_model: OperationModel, mime_type: str ) -> Optional[ETree.Element]: if shape is None: return # The botocore serializer expects `shape.serialization["name"]`, but this isn't always present for responses root_name = shape.serialization.get("name", shape.name) pseudo_root = ETree.Element("") - self._serialize(shape, params, pseudo_root, root_name) + self._serialize(shape, params, pseudo_root, root_name, mime_type) real_root = list(pseudo_root)[0] return real_root - def _serialize(self, shape: Shape, params: Any, xmlnode: ETree.Element, name: str) -> None: + def _serialize( + self, shape: Shape, params: Any, xmlnode: ETree.Element, name: str, mime_type: str + ) -> None: """This method dynamically invokes the correct `_serialize_type_*` method for each shape type.""" if shape is None: return @@ -569,14 +654,14 @@ def _serialize(self, shape: Shape, params: Any, xmlnode: ETree.Element, name: st try: method = getattr(self, "_serialize_type_%s" % shape.type_name, self._default_serialize) - method(xmlnode, params, shape, name) + method(xmlnode, params, shape, name, mime_type) except (TypeError, ValueError, AttributeError) as e: raise ProtocolSerializerError( f"Invalid type when serializing {shape.name}: '{xmlnode}' cannot be parsed to {shape.type_name}." ) from e def _serialize_type_structure( - self, xmlnode: ETree.Element, params: dict, shape: StructureShape, name: str + self, xmlnode: ETree.Element, params: dict, shape: StructureShape, name: str, mime_type ) -> None: structure_node = ETree.SubElement(xmlnode, name) @@ -608,10 +693,10 @@ def _serialize_type_structure( xml_attribute_name = member_shape.serialization["name"] structure_node.attrib[xml_attribute_name] = value continue - self._serialize(member_shape, value, structure_node, member_name) + self._serialize(member_shape, value, structure_node, member_name, mime_type) def _serialize_type_list( - self, xmlnode: ETree.Element, params: list, shape: ListShape, name: str + self, xmlnode: ETree.Element, params: list, shape: ListShape, name: str, mime_type: str ) -> None: if params is None: # Don't serialize any param whose value is None. @@ -628,10 +713,10 @@ def _serialize_type_list( for item in params: # Don't serialize any item which is None if item is not None: - self._serialize(member_shape, item, list_node, element_name) + self._serialize(member_shape, item, list_node, element_name, mime_type) def _serialize_type_map( - self, xmlnode: ETree.Element, params: dict, shape: MapShape, name: str + self, xmlnode: ETree.Element, params: dict, shape: MapShape, name: str, mime_type: str ) -> None: """ Given the ``name`` of MyMap, an input of {"key1": "val1", "key2": "val2"}, and the ``flattened: False`` @@ -673,11 +758,11 @@ def _serialize_type_map( entry_node = ETree.SubElement(entries_node, entry_node_name) key_name = self._get_serialized_name(shape.key, default_name="key") val_name = self._get_serialized_name(shape.value, default_name="value") - self._serialize(shape.key, key, entry_node, key_name) - self._serialize(shape.value, value, entry_node, val_name) + self._serialize(shape.key, key, entry_node, key_name, mime_type) + self._serialize(shape.value, value, entry_node, val_name, mime_type) @staticmethod - def _serialize_type_boolean(xmlnode: ETree.Element, params: bool, _, name: str) -> None: + def _serialize_type_boolean(xmlnode: ETree.Element, params: bool, _, name: str, __) -> None: """ For scalar types, the 'params' attr is actually just a scalar value representing the data we need to serialize as a boolean. It will either be 'true' or 'false' @@ -690,20 +775,29 @@ def _serialize_type_boolean(xmlnode: ETree.Element, params: bool, _, name: str) node.text = str_value def _serialize_type_blob( - self, xmlnode: ETree.Element, params: Union[str, bytes], _, name: str + self, xmlnode: ETree.Element, params: Union[str, bytes], _, name: str, __ ) -> None: node = ETree.SubElement(xmlnode, name) node.text = self._get_base64(params) def _serialize_type_timestamp( - self, xmlnode: ETree.Element, params: str, shape: Shape, name: str + self, xmlnode: ETree.Element, params: str, shape: Shape, name: str, mime_type: str ) -> None: node = ETree.SubElement(xmlnode, name) - node.text = self._convert_timestamp_to_str( - params, shape.serialization.get("timestampFormat") - ) + if mime_type != APPLICATION_JSON: + # Default XML timestamp serialization + node.text = self._convert_timestamp_to_str( + params, shape.serialization.get("timestampFormat") + ) + else: + # For services with XML protocols, where the Accept header is JSON, timestamps are formatted like for JSON + # protocols, but using the int representation instead of the float representation (f.e. requesting JSON + # responses in STS). + node.text = str( + int(self._convert_timestamp_to_str(params, JSONResponseSerializer.TIMESTAMP_FORMAT)) + ) - def _default_serialize(self, xmlnode: ETree.Element, params: str, _, name: str) -> None: + def _default_serialize(self, xmlnode: ETree.Element, params: str, _, name: str, __) -> None: node = ETree.SubElement(xmlnode, name) node.text = str(params) @@ -715,17 +809,25 @@ def _prepare_additional_traits_in_xml(self, root: Optional[ETree.Element]): """ pass - def _create_default_response(self, operation_model: OperationModel) -> HttpResponse: - response = super()._create_default_response(operation_model) - response.headers["Content-Type"] = "text/xml" + def _create_default_response( + self, operation_model: OperationModel, mime_type: str + ) -> HttpResponse: + response = super()._create_default_response(operation_model, mime_type) + response.headers["Content-Type"] = mime_type return response - def _xml_to_string(self, root: Optional[ETree.Element]) -> Optional[str]: + def _node_to_string(self, root: Optional[ETree.Element], mime_type: str) -> Optional[str]: """Generates the string representation of the given XML element.""" if root is not None: - return ETree.tostring( + content = ETree.tostring( element=root, encoding=self.DEFAULT_ENCODING, xml_declaration=True ) + if mime_type == APPLICATION_JSON: + # FIXME try to directly convert the ElementTree node to JSON + xml_dict = xmltodict.parse(content) + xml_dict = strip_xmlns(xml_dict) + content = json.dumps(xml_dict) + return content class BaseRestResponseSerializer(ResponseSerializer, ABC): @@ -743,14 +845,17 @@ def _serialize_response( shape: Optional[Shape], shape_members: dict, operation_model: OperationModel, + mime_type: str, ) -> None: header_params, payload_params = self._partition_members(parameters, shape) self._process_header_members(header_params, response, shape) # "HEAD" responses are basically "GET" responses without the actual body. # Do not process the body payload in this case (setting a body could also manipulate the headers) if operation_model.http.get("method") != "HEAD": - self._serialize_payload(payload_params, response, shape, shape_members, operation_model) - self._serialize_content_type(response, shape, shape_members) + self._serialize_payload( + payload_params, response, shape, shape_members, operation_model, mime_type + ) + self._serialize_content_type(response, shape, shape_members, mime_type) self._prepare_additional_traits_in_response(response, operation_model) def _serialize_payload( @@ -760,6 +865,7 @@ def _serialize_payload( shape: Optional[Shape], shape_members: dict, operation_model: OperationModel, + mime_type: str, ) -> None: """ Serializes the given payload. @@ -769,6 +875,7 @@ def _serialize_payload( :param shape: Describes the expected output shape (can be None in case of an "empty" response) :param shape_members: The members of the output struct shape :param operation_model: The specification of the operation of which the response is serialized here + :param mime_type: Mime type which should be used to encode the payload :return: None - the given `serialized` dict is modified """ if shape is None: @@ -790,7 +897,7 @@ def _serialize_payload( response.set_response( self._encode_payload( self._serialize_body_params( - body_params, shape_members[payload_member], operation_model + body_params, shape_members[payload_member], operation_model, mime_type ) ) ) @@ -798,15 +905,16 @@ def _serialize_payload( # Otherwise, we use the "traditional" way of serializing the whole parameters dict recursively. response.set_response( self._encode_payload( - self._serialize_body_params(parameters, shape, operation_model) + self._serialize_body_params(parameters, shape, operation_model, mime_type) ) ) - def _serialize_content_type(self, serialized: HttpResponse, shape: Shape, shape_members: dict): + def _serialize_content_type( + self, serialized: HttpResponse, shape: Shape, shape_members: dict, mime_type: str + ): """ - Some protocols require varied Content-Type headers - depending on user input. This allows subclasses to apply - this conditionally. + Some protocols require varied Content-Type headers depending on user input. + This allows subclasses to apply this conditionally. """ pass @@ -877,7 +985,7 @@ def _partition_members(self, parameters: dict, shape: Optional[Shape]) -> Tuple[ """Separates the top-level keys in the given parameters dict into header- and payload-located params.""" if not isinstance(shape, StructureShape): # If the shape isn't a structure, we default to the whole response being parsed in the body. - # Non-payload members are only loated in the top-level hierarchy and those are always structures. + # Non-payload members are only loaded in the top-level hierarchy and those are always structures. return {}, parameters header_params = {} payload_params = {} @@ -919,6 +1027,7 @@ def _serialize_response( shape: Optional[Shape], shape_members: dict, operation_model: OperationModel, + mime_type: str, ) -> None: """ Serializes the given parameters as XML for the query protocol. @@ -928,19 +1037,22 @@ def _serialize_response( :param shape: Describes the expected output shape (can be None in case of an "empty" response) :param shape_members: The members of the output struct shape :param operation_model: The specification of the operation of which the response is serialized here + :param mime_type: Mime type which should be used to encode the payload :return: None - the given `serialized` dict is modified """ response.set_response( - self._encode_payload(self._serialize_body_params(parameters, shape, operation_model)) + self._encode_payload( + self._serialize_body_params(parameters, shape, operation_model, mime_type) + ) ) def _serialize_body_params_to_xml( - self, params: dict, shape: Shape, operation_model: OperationModel + self, params: dict, shape: Shape, operation_model: OperationModel, mime_type: str ) -> ETree.Element: # The Query protocol responses have a root element which is not contained in the specification file. # Therefore, we first call the super function to perform the normal XML serialization, and afterwards wrap the # result in a root element based on the operation name. - node = super()._serialize_body_params_to_xml(params, shape, operation_model) + node = super()._serialize_body_params_to_xml(params, shape, operation_model, mime_type) # Check if we need to add a namespace attr = ( @@ -976,6 +1088,7 @@ def _serialize_error( response: HttpResponse, shape: StructureShape, operation_model: OperationModel, + mime_type: str, ) -> None: # EC2 errors look like: # @@ -997,10 +1110,10 @@ def _serialize_error( root = ETree.Element("Response", attr) errors_tag = ETree.SubElement(root, "Errors") error_tag = ETree.SubElement(errors_tag, "Error") - self._add_error_tags(error, error_tag) + self._add_error_tags(error, error_tag, mime_type) request_id = ETree.SubElement(root, "RequestID") request_id.text = gen_amzn_requestid_long() - response.set_response(self._encode_payload(self._xml_to_string(root))) + response.set_response(self._encode_payload(self._node_to_string(root, mime_type))) def _prepare_additional_traits_in_xml(self, root: Optional[ETree.Element]): # The EC2 protocol does not use the root output shape, therefore we need to remove the hierarchy level @@ -1024,6 +1137,10 @@ class JSONResponseSerializer(ResponseSerializer): ``RestJSONResponseSerializer``. """ + JSON_TYPES = [APPLICATION_JSON, APPLICATION_AMZ_JSON_1_0, APPLICATION_AMZ_JSON_1_1] + CBOR_TYPES = [APPLICATION_CBOR, APPLICATION_AMZ_CBOR_1_1] + SUPPORTED_MIME_TYPES = JSON_TYPES + CBOR_TYPES + TIMESTAMP_FORMAT = "unixtimestamp" def _serialize_error( @@ -1032,8 +1149,9 @@ def _serialize_error( response: HttpResponse, shape: StructureShape, operation_model: OperationModel, + mime_type: str, ) -> None: - body = {} + body = dict() # TODO implement different service-specific serializer configurations # - currently we set both, the `__type` member as well as the `X-Amzn-Errortype` header @@ -1050,7 +1168,7 @@ def _serialize_error( # Default error message fields can sometimes have different casing in the specs elif member.lower() in ["code", "message"] and hasattr(error, member.lower()): remaining_params[member] = getattr(error, member.lower()) - self._serialize(body, remaining_params, shape) + self._serialize(body, remaining_params, shape, None, mime_type) # Only set the message if it has not been set with the shape members if "message" not in body and "Message" not in body: @@ -1058,7 +1176,11 @@ def _serialize_error( if message is not None: body["message"] = message - response.set_json(body) + if mime_type in self.CBOR_TYPES: + response.set_response(cbor2.dumps(body)) + response.content_type = mime_type + else: + response.set_json(body) def _serialize_response( self, @@ -1067,31 +1189,43 @@ def _serialize_response( shape: Optional[Shape], shape_members: dict, operation_model: OperationModel, + mime_type: str, ) -> None: - json_version = operation_model.metadata.get("jsonVersion") - if json_version is not None: - response.headers["Content-Type"] = "application/x-amz-json-%s" % json_version - response.set_response(self._serialize_body_params(parameters, shape, operation_model)) + if mime_type in self.CBOR_TYPES: + response.content_type = mime_type + else: + json_version = operation_model.metadata.get("jsonVersion") + if json_version is not None: + response.headers["Content-Type"] = "application/x-amz-json-%s" % json_version + response.set_response( + self._serialize_body_params(parameters, shape, operation_model, mime_type) + ) def _serialize_body_params( - self, params: dict, shape: Shape, operation_model: OperationModel + self, params: dict, shape: Shape, operation_model: OperationModel, mime_type: str ) -> Optional[str]: body = {} if shape is not None: - self._serialize(body, params, shape) - return json.dumps(body) + self._serialize(body, params, shape, None, mime_type) - def _serialize(self, body: dict, value: Any, shape, key: Optional[str] = None): + if mime_type in self.CBOR_TYPES: + return cbor2.dumps(body) + else: + return json.dumps(body) + + def _serialize(self, body: dict, value: Any, shape, key: Optional[str], mime_type: str): """This method dynamically invokes the correct `_serialize_type_*` method for each shape type.""" try: method = getattr(self, "_serialize_type_%s" % shape.type_name, self._default_serialize) - method(body, value, shape, key) + method(body, value, shape, key, mime_type) except (TypeError, ValueError, AttributeError) as e: raise ProtocolSerializerError( f"Invalid type when serializing {shape.name}: '{value}' cannot be parsed to {shape.type_name}." ) from e - def _serialize_type_structure(self, body: dict, value: dict, shape: StructureShape, key: str): + def _serialize_type_structure( + self, body: dict, value: dict, shape: StructureShape, key: Optional[str], mime_type: str + ): if value is None: return if shape.is_document_type: @@ -1099,7 +1233,7 @@ def _serialize_type_structure(self, body: dict, value: dict, shape: StructureSha else: if key is not None: # If a key is provided, this is a result of a recursive - # call so we need to add a new child dict as the value + # call, so we need to add a new child dict as the value # of the passed in serialized dict. We'll then add # all the structure members as key/vals in the new serialized # dictionary we just created. @@ -1121,18 +1255,22 @@ def _serialize_type_structure(self, body: dict, value: dict, shape: StructureSha continue if "name" in member_shape.serialization: member_key = member_shape.serialization["name"] - self._serialize(body, member_value, member_shape, member_key) + self._serialize(body, member_value, member_shape, member_key, mime_type) - def _serialize_type_map(self, body: dict, value: dict, shape: MapShape, key: str): + def _serialize_type_map( + self, body: dict, value: dict, shape: MapShape, key: str, mime_type: str + ): if value is None: return map_obj = {} body[key] = map_obj for sub_key, sub_value in value.items(): if sub_value is not None: - self._serialize(map_obj, sub_value, shape.value, sub_key) + self._serialize(map_obj, sub_value, shape.value, sub_key, mime_type) - def _serialize_type_list(self, body: dict, value: list, shape: ListShape, key: str): + def _serialize_type_list( + self, body: dict, value: list, shape: ListShape, key: str, mime_type: str + ): if value is None: return list_obj = [] @@ -1144,19 +1282,24 @@ def _serialize_type_list(self, body: dict, value: list, shape: ListShape, key: s # setting a key on a dict. We handle this by using # a __current__ key on a wrapper dict to serialize each # list item before appending it to the serialized list. - self._serialize(wrapper, list_item, shape.member, "__current__") + self._serialize(wrapper, list_item, shape.member, "__current__", mime_type) list_obj.append(wrapper["__current__"]) - def _default_serialize(self, body: dict, value: Any, _, key: str): + def _default_serialize(self, body: dict, value: Any, _, key: str, __): body[key] = value - def _serialize_type_timestamp(self, body: dict, value: Any, shape: Shape, key: str): + def _serialize_type_timestamp(self, body: dict, value: Any, shape: Shape, key: str, _): body[key] = self._convert_timestamp_to_str( value, shape.serialization.get("timestampFormat") ) - def _serialize_type_blob(self, body: dict, value: Union[str, bytes], _, key: str): - body[key] = self._get_base64(value) + def _serialize_type_blob( + self, body: dict, value: Union[str, bytes], _, key: str, mime_type: str + ): + if mime_type in self.CBOR_TYPES: + body[key] = value + else: + body[key] = self._get_base64(value) def _prepare_additional_traits_in_response( self, response: HttpResponse, operation_model: OperationModel @@ -1174,7 +1317,9 @@ class RestJSONResponseSerializer(BaseRestResponseSerializer, JSONResponseSeriali (for the JSOn body response serialization). """ - def _serialize_content_type(self, serialized: HttpResponse, shape: Shape, shape_members: dict): + def _serialize_content_type( + self, serialized: HttpResponse, shape: Shape, shape_members: dict, mime_type: str + ): """Set Content-Type to application/json for all structured bodies.""" payload = shape.serialization.get("payload") if shape is not None else None if self._has_streaming_payload(payload, shape_members): @@ -1184,7 +1329,7 @@ def _serialize_content_type(self, serialized: HttpResponse, shape: Shape, shape_ has_body = serialized.data != b"" has_content_type = self._has_header("Content-Type", serialized.headers) if has_body and not has_content_type: - serialized.headers["Content-Type"] = "application/json" + serialized.headers["Content-Type"] = mime_type class S3ResponseSerializer(RestXMLResponseSerializer): @@ -1199,6 +1344,7 @@ def _serialize_error( response: HttpResponse, shape: StructureShape, operation_model: OperationModel, + mime_type: str, ) -> None: attr = ( {"xmlns": operation_model.metadata.get("xmlNamespace")} @@ -1206,11 +1352,12 @@ def _serialize_error( else {} ) root = ETree.Element("Error", attr) - self._add_error_tags(error, root) + self._add_error_tags(error, root, mime_type) request_id = ETree.SubElement(root, "RequestId") request_id.text = gen_amzn_requestid_long() - self._add_additional_error_tags(error, root, shape) - response.set_response(self._encode_payload(self._xml_to_string(root))) + self._add_additional_error_tags(error, root, shape, mime_type) + + response.set_response(self._encode_payload(self._node_to_string(root, mime_type))) class SqsResponseSerializer(QueryResponseSerializer): @@ -1227,17 +1374,17 @@ class SqsResponseSerializer(QueryResponseSerializer): - These double-escapes are corrected by replacing such strings with their original. """ - def _default_serialize(self, xmlnode: ETree.Element, params: str, _, name: str) -> None: + def _default_serialize(self, xmlnode: ETree.Element, params: str, _, name: str, __) -> None: """Ensures that XML text nodes use HTML entities instead of " or \r""" node = ETree.SubElement(xmlnode, name) node.text = str(params).replace('"', """).replace("\r", " ") - def _xml_to_string(self, root: Optional[ETree.ElementTree]) -> Optional[str]: + def _node_to_string(self, root: Optional[ETree.ElementTree], mime_type: str) -> Optional[str]: """ Replaces the double-escaped HTML entities with their correct HTML entity (basically reverts the escaping in the serialization of the used XML framework). """ - generated_string = super()._xml_to_string(root) + generated_string = super()._node_to_string(root, mime_type) return ( to_bytes( to_str(generated_string) diff --git a/localstack/aws/skeleton.py b/localstack/aws/skeleton.py index 558362af4d936..9cc0b03de0829 100644 --- a/localstack/aws/skeleton.py +++ b/localstack/aws/skeleton.py @@ -171,7 +171,7 @@ def dispatch_request(self, context: RequestContext, instance: ServiceRequest) -> context.service_response = result # Serialize result dict to an HTTPResponse and return it - return self.serializer.serialize_to_response(result, operation) + return self.serializer.serialize_to_response(result, operation, context.request.headers) def on_service_exception( self, context: RequestContext, exception: ServiceException @@ -185,7 +185,9 @@ def on_service_exception( """ context.service_exception = exception - return self.serializer.serialize_error_to_response(exception, context.operation) + return self.serializer.serialize_error_to_response( + exception, context.operation, context.request.headers + ) def on_not_implemented_error(self, context: RequestContext) -> HttpResponse: """ @@ -211,4 +213,4 @@ def on_not_implemented_error(self, context: RequestContext) -> HttpResponse: ) context.service_exception = error - return serializer.serialize_error_to_response(error, operation) + return serializer.serialize_error_to_response(error, operation, context.request.headers) diff --git a/localstack/constants.py b/localstack/constants.py index 63e5c483bc2dc..097f989ad9f02 100644 --- a/localstack/constants.py +++ b/localstack/constants.py @@ -90,6 +90,7 @@ # content types / encodings HEADER_CONTENT_TYPE = "Content-Type" +TEXT_XML = "text/xml" APPLICATION_AMZ_JSON_1_0 = "application/x-amz-json-1.0" APPLICATION_AMZ_JSON_1_1 = "application/x-amz-json-1.1" APPLICATION_AMZ_CBOR_1_1 = "application/x-amz-cbor-1.1" diff --git a/localstack/services/kinesis/provider.py b/localstack/services/kinesis/provider.py index 0e2b684c4ee94..ca89ce0185449 100644 --- a/localstack/services/kinesis/provider.py +++ b/localstack/services/kinesis/provider.py @@ -52,15 +52,6 @@ LOG = logging.getLogger(__name__) -# TODO ASF: Check if we need to implement CBOR encoding in the serializer! -# TODO ASF: Set "X-Amzn-Errortype" (HEADER_AMZN_ERROR_TYPE) on responses -# TODO ASF: Rewrite responses -# - Region in content of responses -# - Record rewriting: -# - SDKv2: Transform timestamps to int? -# - Remove double quotes for JSON responses -# - Convert base64 encoded data back to bytes for the cbor encoding - class KinesisBackend(RegionBackend): def __init__(self): diff --git a/localstack/services/moto.py b/localstack/services/moto.py index 662d3ddc4e00b..61e42f8c2faef 100644 --- a/localstack/services/moto.py +++ b/localstack/services/moto.py @@ -2,8 +2,8 @@ This module provides tools to call moto using moto and botocore internals without going through the moto HTTP server. """ import sys -from functools import lru_cache -from typing import Callable +from functools import lru_cache, partial +from typing import Callable, Optional, Union from moto.backends import get_backend as get_moto_backend from moto.core.exceptions import RESTError @@ -21,8 +21,11 @@ ServiceRequest, ServiceResponse, ) -from localstack.aws.client import parse_response, raise_service_exception -from localstack.aws.forwarder import ForwardingFallbackDispatcher, create_aws_request_context +from localstack.aws.forwarder import ( + ForwardingFallbackDispatcher, + create_aws_request_context, + dispatch_to_backend, +) from localstack.aws.skeleton import DispatchTable from localstack.http import Response @@ -31,23 +34,6 @@ user_agent = f"Localstack/{localstack_version} Python/{sys.version.split(' ')[0]}" -def call_moto(context: RequestContext, include_response_metadata=False) -> ServiceResponse: - """ - Call moto with the given request context and receive a parsed ServiceResponse. - - :param context: the request context - :param include_response_metadata: whether to include botocore's "ResponseMetadata" attribute - :return: an AWS ServiceResponse (same as a service provider would return) - :raises ServiceException: if moto returned an error response - """ - status, headers, content = dispatch_to_moto(context) - response = Response(content, status, headers) - parsed_response = parse_response(context.operation, response, include_response_metadata) - raise_service_exception(response, parsed_response) - - return parsed_response - - def call_moto_with_request( context: RequestContext, service_request: ServiceRequest ) -> ServiceResponse: @@ -72,18 +58,19 @@ def call_moto_with_request( return call_moto(local_context) -def proxy_moto(context: RequestContext, service_request: ServiceRequest = None) -> Response: +def proxy_moto( + context: RequestContext, service_request: ServiceRequest = None +) -> Optional[Union[ServiceResponse]]: """ Similar to ``call``, only that ``proxy`` does not parse the HTTP response into a ServiceResponse, but instead returns directly the HTTP response. This can be useful to pass through moto's response directly to the client. :param context: the request context :param service_request: currently not being used, added to satisfy ServiceRequestHandler contract - :return: the Response from moto + :return: the Response from moto or the ServiceResponse dictionary (to be serialized again) in case the Content-Type + of the response does not explicitly match the Accept header of the request """ - status, headers, content = dispatch_to_moto(context) - - return Response(response=content, status=status, headers=headers) + return dispatch_to_backend(context, dispatch_to_moto) def MotoFallbackDispatcher(provider: object) -> DispatchTable: @@ -111,7 +98,8 @@ def dispatch_to_moto(context: RequestContext) -> Response: dispatch = get_dispatcher(service.service_name, request.path) try: - return dispatch(request, request.url, request.headers) + status, headers, content = dispatch(request, request.url, request.headers) + return Response(content, status, headers) except RESTError as e: raise CommonServiceException(e.error_type, e.message, status_code=e.code) from e @@ -171,3 +159,14 @@ def load_moto_routing_table(service: str) -> Map: url_map.add(Rule(url_path, endpoint=endpoint, strict_slashes=strict_slashes)) return url_map + + +call_moto = partial(dispatch_to_backend, http_request_dispatcher=dispatch_to_moto) +""" +Call moto with the given request context and receive a parsed ServiceResponse. + +:param context: the request context +:param include_response_metadata: whether to include botocore's "ResponseMetadata" attribute +:return: an AWS ServiceResponse (same as a service provider would return) +:raises ServiceException: if moto returned an error response +""" diff --git a/localstack/services/providers.py b/localstack/services/providers.py index 290f1d1493d37..2a1b14719a5ef 100644 --- a/localstack/services/providers.py +++ b/localstack/services/providers.py @@ -142,10 +142,10 @@ def iam(): @aws_provider() def sts(): - from localstack.services.sts.provider import StsAwsApiListener + from localstack.services.sts.provider import StsProvider - listener = StsAwsApiListener() - return Service("sts", listener=listener) + provider = StsProvider() + return Service("sts", listener=AwsApiListener("sts", MotoFallbackDispatcher(provider))) @aws_provider(api="kinesis", name="legacy") diff --git a/localstack/services/sqs/query_api.py b/localstack/services/sqs/query_api.py index 2ff4b33acb8ff..209f8241519fa 100644 --- a/localstack/services/sqs/query_api.py +++ b/localstack/services/sqs/query_api.py @@ -119,13 +119,13 @@ def handle_request(request: Request, region: str) -> Response: try: response, operation = try_call_sqs(request, region) del response["ResponseMetadata"] - return serializer.serialize_to_response(response, operation) + return serializer.serialize_to_response(response, operation, request.headers) except UnknownOperationException: return Response("", 404) except CommonServiceException as e: # use a dummy operation for the serialization to work op = service.operation_model(service.operation_names[0]) - return serializer.serialize_error_to_response(e, op) + return serializer.serialize_error_to_response(e, op, request.headers) except Exception as e: LOG.exception("exception") op = service.operation_model(service.operation_names[0]) @@ -134,6 +134,7 @@ def handle_request(request: Request, region: str) -> Response: "InternalError", f"An internal error ocurred: {e}", status_code=500 ), op, + request.headers, ) diff --git a/localstack/services/sts/provider.py b/localstack/services/sts/provider.py index e4c3a7e27503b..dced9875a774b 100644 --- a/localstack/services/sts/provider.py +++ b/localstack/services/sts/provider.py @@ -1,18 +1,9 @@ import logging -import re - -import xmltodict from localstack.aws.api import RequestContext from localstack.aws.api.sts import GetCallerIdentityResponse, StsApi -from localstack.aws.proxy import AwsApiListener -from localstack.constants import APPLICATION_JSON -from localstack.http import Request, Response -from localstack.services.moto import MotoFallbackDispatcher, call_moto +from localstack.services.moto import call_moto from localstack.services.plugins import ServiceLifecycleHook -from localstack.utils.strings import to_str -from localstack.utils.time import parse_timestamp -from localstack.utils.xml import strip_xmlns LOG = logging.getLogger(__name__) @@ -23,32 +14,3 @@ def get_caller_identity(self, context: RequestContext) -> GetCallerIdentityRespo if "user/moto" in result["Arn"] and "sts" in result["Arn"]: result["Arn"] = f"arn:aws:iam::{result['Account']}:root" return result - - -class StsAwsApiListener(AwsApiListener): - def __init__(self): - self.provider = StsProvider() - super().__init__("sts", MotoFallbackDispatcher(self.provider)) - - def request(self, request: Request) -> Response: - response = super().request(request) - - if request.headers.get("Accept") == APPLICATION_JSON: - # convert "Expiration" to int for JSON response format (tested against AWS) - # TODO: introduce a proper/generic approach that works across arbitrary date fields in JSON - - def _replace(match): - timestamp = parse_timestamp(match.group(1).strip()) - return f"{int(timestamp.timestamp())}" - - def _replace_response_content(_pattern, _replacement): - content = to_str(response.data or "") - data = re.sub(_pattern, _replacement, content) - content = xmltodict.parse(data) - stripped_content = strip_xmlns(content) - response.set_json(stripped_content) - - pattern = r"([^<]+)" - _replace_response_content(pattern, _replace) - - return response diff --git a/tests/integration/apigateway_fixtures.py b/tests/integration/apigateway_fixtures.py index 68cbcb6b85348..2c92b7a97c0b8 100644 --- a/tests/integration/apigateway_fixtures.py +++ b/tests/integration/apigateway_fixtures.py @@ -93,7 +93,7 @@ def create_rest_resource_method(apigateway_client, **kwargs): def create_rest_authorizer(apigateway_client, **kwargs): response = apigateway_client.create_authorizer(**kwargs) - assert_response_is_200(response) + assert_response_is_201(response) return response.get("id"), response.get("type") @@ -127,12 +127,12 @@ def create_rest_api_integration_response(apigateway_client, **kwargs): def create_domain_name(apigateway_client, **kwargs): response = apigateway_client.create_domain_name(**kwargs) - assert_response_is_200(response) + assert_response_is_201(response) def create_base_path_mapping(apigateway_client, **kwargs): response = apigateway_client.create_base_path_mapping(**kwargs) - assert_response_is_200(response) + assert_response_is_201(response) return response.get("basePath"), response.get("stage") diff --git a/tests/integration/test_apigateway.py b/tests/integration/test_apigateway.py index 02dada9379b2c..9549df0e31e8a 100644 --- a/tests/integration/test_apigateway.py +++ b/tests/integration/test_apigateway.py @@ -856,7 +856,7 @@ def test_api_gateway_handle_domain_name(self): domain_name = f"{short_uid()}.example.com" apigw_client = aws_stack.create_external_boto_client("apigateway") rs = apigw_client.create_domain_name(domainName=domain_name) - assert 200 == rs["ResponseMetadata"]["HTTPStatusCode"] + assert 201 == rs["ResponseMetadata"]["HTTPStatusCode"] rs = apigw_client.get_domain_name(domainName=domain_name) assert 200 == rs["ResponseMetadata"]["HTTPStatusCode"] assert domain_name == rs["domainName"] diff --git a/tests/integration/test_apigateway_api.py b/tests/integration/test_apigateway_api.py index 3b8572c21ed0f..16a6a8ca4740c 100644 --- a/tests/integration/test_apigateway_api.py +++ b/tests/integration/test_apigateway_api.py @@ -259,6 +259,7 @@ def test_integration_response(apigateway_client): # this is hard to match against, so remove it response["ResponseMetadata"].pop("HTTPHeaders", None) response["ResponseMetadata"].pop("RetryAttempts", None) + response["ResponseMetadata"].pop("RequestId", None) assert response == ( { "statusCode": "200", @@ -274,6 +275,7 @@ def test_integration_response(apigateway_client): # this is hard to match against, so remove it response["ResponseMetadata"].pop("HTTPHeaders", None) response["ResponseMetadata"].pop("RetryAttempts", None) + response["ResponseMetadata"].pop("RequestId", None) assert response == ( { "statusCode": "200", @@ -287,6 +289,7 @@ def test_integration_response(apigateway_client): # this is hard to match against, so remove it response["ResponseMetadata"].pop("HTTPHeaders", None) response["ResponseMetadata"].pop("RetryAttempts", None) + response["ResponseMetadata"].pop("RequestId", None) assert response["methodIntegration"]["integrationResponses"] == ( { "200": { @@ -338,6 +341,7 @@ def test_integration_response(apigateway_client): # this is hard to match against, so remove it response["ResponseMetadata"].pop("HTTPHeaders", None) response["ResponseMetadata"].pop("RetryAttempts", None) + response["ResponseMetadata"].pop("RequestId", None) assert response == ( { "statusCode": "200", @@ -354,6 +358,7 @@ def test_integration_response(apigateway_client): # this is hard to match against, so remove it response["ResponseMetadata"].pop("HTTPHeaders", None) response["ResponseMetadata"].pop("RetryAttempts", None) + response["ResponseMetadata"].pop("RequestId", None) assert response == ( { "statusCode": "200", @@ -402,6 +407,7 @@ def test_put_integration_response_with_response_template(apigateway_client): # this is hard to match against, so remove it response["ResponseMetadata"].pop("HTTPHeaders", None) response["ResponseMetadata"].pop("RetryAttempts", None) + response["ResponseMetadata"].pop("RequestId", None) assert response == { "statusCode": "200", "selectionPattern": "foobar", diff --git a/tests/integration/test_moto.py b/tests/integration/test_moto.py index 4cafb66b291a9..a824f8eafbf21 100644 --- a/tests/integration/test_moto.py +++ b/tests/integration/test_moto.py @@ -6,7 +6,7 @@ from localstack.aws.api import ServiceException, handler from localstack.services import moto from localstack.services.moto import MotoFallbackDispatcher -from localstack.utils.common import short_uid, to_str +from localstack.utils.common import short_uid def test_call_with_sqs_creates_state_correctly(): @@ -192,31 +192,28 @@ def test_call_multi_region_backends(): del sqs_backends["eu-central-1"].queues[qname_eu] -def test_proxy_with_sqs_invalid_call_returns_error(): - response = moto.proxy_moto( - moto.create_aws_request_context( - "sqs", - "DeleteQueue", - { - "QueueUrl": "http://0.0.0.0/nonexistingqueue", - }, +def test_proxy_with_sqs_invalid_call_raises_exception(): + with pytest.raises(ServiceException): + moto.proxy_moto( + moto.create_aws_request_context( + "sqs", + "DeleteQueue", + { + "QueueUrl": "http://0.0.0.0/nonexistingqueue", + }, + ) ) - ) - - assert response.status_code == 400 - assert "NonExistentQueue" in to_str(response.data) -def test_proxy_with_sqs_returns_http_response(): +def test_proxy_with_sqs_returns_service_response(): qname = f"queue-{short_uid()}" - response = moto.proxy_moto( + create_queue_response = moto.proxy_moto( moto.create_aws_request_context("sqs", "CreateQueue", {"QueueName": qname}) ) - assert response.status_code == 200 - assert f"{qname}" in to_str(response.data) - assert "x-amzn-requestid" in response.headers + assert "QueueUrl" in create_queue_response + assert create_queue_response["QueueUrl"].endswith(qname) class FakeSqsApi: @@ -252,14 +249,14 @@ def _dispatch(action, params): return dispatcher[action](context, params) qname = f"queue-{short_uid()}" - # when falling through the dispatcher returns an HTTP response - http_response = _dispatch("CreateQueue", {"QueueName": qname}) - assert http_response.status_code == 200 + # when falling through the dispatcher returns the appropriate ServiceResponse (in this case a CreateQueueResult) + create_queue_response = _dispatch("CreateQueue", {"QueueName": qname}) + assert "QueueUrl" in create_queue_response - # this returns an - response = _dispatch("ListQueues", None) + # this returns a ListQueuesResult + list_queues_response = _dispatch("ListQueues", None) assert len(provider.calls) == 1 - assert len([url for url in response["QueueUrls"] if qname in url]) + assert len([url for url in list_queues_response["QueueUrls"] if qname in url]) def test_request_with_response_header_location_fields(): diff --git a/tests/integration/test_s3control.py b/tests/integration/test_s3control.py index e3dc998289234..be05804b4506b 100644 --- a/tests/integration/test_s3control.py +++ b/tests/integration/test_s3control.py @@ -26,7 +26,7 @@ def test_lifecycle_public_access_block(): AccountId=get_aws_account_id(), PublicAccessBlockConfiguration=access_block_config ) - assert put_response["ResponseMetadata"]["HTTPStatusCode"] == 201 + assert put_response["ResponseMetadata"]["HTTPStatusCode"] == 200 get_response = s3control_client.get_public_access_block(AccountId=get_aws_account_id()) assert access_block_config == get_response["PublicAccessBlockConfiguration"] diff --git a/tests/unit/aws/handlers/service.py b/tests/unit/aws/handlers/service.py index 638f8aefcf896..d748bbc8ad3a5 100644 --- a/tests/unit/aws/handlers/service.py +++ b/tests/unit/aws/handlers/service.py @@ -26,7 +26,7 @@ def test_parse_response(self, service_response_handler_chain): context = create_aws_request_context("sqs", "CreateQueue", {"QueueName": "foobar"}) backend_response = {"QueueUrl": "http://localhost:4566/000000000000/foobar"} http_response = create_serializer(context.service).serialize_to_response( - backend_response, context.operation + backend_response, context.operation, context.request.headers ) service_response_handler_chain.handle(context, http_response) @@ -36,7 +36,7 @@ def test_parse_response_with_streaming_response(self, service_response_handler_c context = create_aws_request_context("s3", "GetObject", {"Bucket": "foo", "Key": "bar.bin"}) backend_response = {"Body": b"\x00\x01foo", "ContentType": "application/octet-stream"} http_response = create_serializer(context.service).serialize_to_response( - backend_response, context.operation + backend_response, context.operation, context.request.headers ) service_response_handler_chain.handle(context, http_response) @@ -63,7 +63,7 @@ def test_service_exception(self, service_response_handler_chain): context.service_exception = ResourceAlreadyExistsException("oh noes") response = create_serializer(context.service).serialize_error_to_response( - context.service_exception, context.operation + context.service_exception, context.operation, context.request.headers ) service_response_handler_chain.handle(context, response) @@ -83,7 +83,7 @@ def test_service_exception_with_code_from_spec(self, service_response_handler_ch context.service_exception = QueueDoesNotExist() response = create_serializer(context.service).serialize_error_to_response( - context.service_exception, context.operation + context.service_exception, context.operation, context.request.headers ) service_response_handler_chain.handle(context, response) diff --git a/tests/unit/aws/protocol/test_serializer.py b/tests/unit/aws/protocol/test_serializer.py index 0353d2321a0ad..93a13ce634b59 100644 --- a/tests/unit/aws/protocol/test_serializer.py +++ b/tests/unit/aws/protocol/test_serializer.py @@ -76,7 +76,7 @@ def _botocore_serializer_integration_test( # The serializer changes the incoming dict, therefore copy it before passing it to the serializer response_to_parse = copy.deepcopy(response) serialized_response = response_serializer.serialize_to_response( - response_to_parse, service.operation_model(action) + response_to_parse, service.operation_model(action), None ) # Use the parser from botocore to parse the serialized response @@ -141,7 +141,7 @@ def _botocore_error_serializer_integration_test( # Use our serializer to serialize the response response_serializer = create_serializer(service) serialized_response = response_serializer.serialize_error_to_response( - exception, service.operation_model(action) + exception, service.operation_model(action), None ) # Use the parser from botocore to parse the serialized response @@ -430,7 +430,7 @@ def test_query_protocol_error_serialization_plain(): # Use our serializer to serialize the response response_serializer = create_serializer(service) serialized_response = response_serializer.serialize_error_to_response( - exception, service.operation_model("ChangeMessageVisibility") + exception, service.operation_model("ChangeMessageVisibility"), None ) serialized_response_dict = serialized_response.to_readonly_response_dict() # Replace the random request ID with a static value for comparison @@ -615,7 +615,7 @@ def test_json_protocol_error_serialization_with_shaped_default_members_on_root() service = load_service("dynamodb") response_serializer = create_serializer(service) serialized_response = response_serializer.serialize_error_to_response( - exception, service.operation_model("ExecuteTransaction") + exception, service.operation_model("ExecuteTransaction"), None ) body = serialized_response.data parsed_body = json.loads(body) @@ -652,7 +652,7 @@ def test_rest_json_protocol_error_serialization_with_shaped_default_members_on_r service = load_service("lambda") response_serializer = create_serializer(service) serialized_response = response_serializer.serialize_error_to_response( - exception, service.operation_model("GetLayerVersion") + exception, service.operation_model("GetLayerVersion"), None ) body = serialized_response.data parsed_body = json.loads(body) @@ -687,7 +687,7 @@ def test_query_protocol_error_serialization_with_default_members_not_on_root(): service = load_service("sns") response_serializer = create_serializer(service) serialized_response = response_serializer.serialize_error_to_response( - exception, service.operation_model("VerifySMSSandboxPhoneNumber") + exception, service.operation_model("VerifySMSSandboxPhoneNumber"), None ) body = serialized_response.data parser = ElementTree.XMLParser(target=ElementTree.TreeBuilder()) @@ -702,7 +702,7 @@ def test_rest_xml_protocol_error_serialization_with_default_members_not_on_root( service = load_service("route53") response_serializer = create_serializer(service) serialized_response = response_serializer.serialize_error_to_response( - exception, service.operation_model("DeleteHostedZone") + exception, service.operation_model("DeleteHostedZone"), None ) body = serialized_response.data parser = ElementTree.XMLParser(target=ElementTree.TreeBuilder()) @@ -740,7 +740,7 @@ def test_json_protocol_content_type_1_0(): service = load_service("apprunner") response_serializer = create_serializer(service) result: Response = response_serializer.serialize_to_response( - {}, service.operation_model("DeleteConnection") + {}, service.operation_model("DeleteConnection"), None ) assert result is not None assert result.content_type is not None @@ -752,7 +752,7 @@ def test_json_protocol_content_type_1_1(): service = load_service("logs") response_serializer = create_serializer(service) result: Response = response_serializer.serialize_to_response( - {}, service.operation_model("DeleteLogGroup") + {}, service.operation_model("DeleteLogGroup"), None ) assert result is not None assert result.content_type is not None @@ -1206,7 +1206,7 @@ def test_ec2_protocol_errors_have_response_root_element(): service = load_service("ec2") response_serializer = create_serializer(service) serialized_response = response_serializer.serialize_error_to_response( - exception, service.operation_model("DescribeSubnets") + exception, service.operation_model("DescribeSubnets"), None ) body = serialized_response.data parser = ElementTree.XMLParser(target=ElementTree.TreeBuilder()) @@ -1221,7 +1221,7 @@ def test_restxml_s3_errors_have_error_root_element(): service = load_service("s3") response_serializer = create_serializer(service) serialized_response = response_serializer.serialize_error_to_response( - exception, service.operation_model("GetObject") + exception, service.operation_model("GetObject"), None ) body = serialized_response.data parser = ElementTree.XMLParser(target=ElementTree.TreeBuilder()) @@ -1373,7 +1373,7 @@ def event_generator() -> Iterator: service = load_service("kinesis") operation_model = service.operation_model("SubscribeToShard") response_serializer = create_serializer(service) - serialized_response = response_serializer.serialize_to_response(response, operation_model) + serialized_response = response_serializer.serialize_to_response(response, operation_model, None) # Convert the Werkzeug response from our serializer to a response botocore can work with urllib_response = UrlLibHttpResponse( @@ -1546,7 +1546,7 @@ def test_no_mutation_of_parameters(): # serialize response and check whether parameters are unchanged _ = response_serializer.serialize_to_response( - parameters, service.operation_model("CreateHostedConfigurationVersion") + parameters, service.operation_model("CreateHostedConfigurationVersion"), None ) assert parameters == expected @@ -1559,7 +1559,7 @@ def test_serializer_error_on_protocol_error_invalid_exception(): with pytest.raises(ProtocolSerializerError): # a known protocol error would be if we try to serialize an exception which is not a CommonServiceException and # also not a generated exception - serializer.serialize_error_to_response(NotImplementedError(), operation_model) + serializer.serialize_error_to_response(NotImplementedError(), operation_model, None) def test_serializer_error_on_protocol_error_invalid_data(): @@ -1569,7 +1569,9 @@ def test_serializer_error_on_protocol_error_invalid_data(): serializer = QueryResponseSerializer() with pytest.raises(ProtocolSerializerError): serializer.serialize_to_response( - {"StreamDescription": {"CreationRequestDateTime": "invalid_timestamp"}}, operation_model + {"StreamDescription": {"CreationRequestDateTime": "invalid_timestamp"}}, + operation_model, + None, ) @@ -1586,7 +1588,7 @@ def raise_error(*args, **kwargs): serializer._serialize_response = raise_error with pytest.raises(UnknownSerializerError): - serializer.serialize_to_response({}, operation_model) + serializer.serialize_to_response({}, operation_model, None) class ComparableBytesIO(BytesIO): From 1c06be6d583d82dbd3c0ea1ffb4f37f32334c9e2 Mon Sep 17 00:00:00 2001 From: Alexander Rashed Date: Mon, 22 Aug 2022 14:40:17 +0200 Subject: [PATCH 07/16] remove Logs AwsApiListener extension --- localstack/services/logs/provider.py | 21 +-------------------- localstack/services/providers.py | 5 +++-- 2 files changed, 4 insertions(+), 22 deletions(-) diff --git a/localstack/services/logs/provider.py b/localstack/services/logs/provider.py index f7bae29cff6a3..aef1639ebd3bf 100644 --- a/localstack/services/logs/provider.py +++ b/localstack/services/logs/provider.py @@ -23,10 +23,7 @@ PutLogEventsResponse, SequenceToken, ) -from localstack.aws.proxy import AwsApiListener -from localstack.constants import APPLICATION_AMZ_JSON_1_1 -from localstack.services.messages import Request, Response -from localstack.services.moto import MotoFallbackDispatcher, call_moto +from localstack.services.moto import call_moto from localstack.services.plugins import ServiceLifecycleHook from localstack.utils.aws import aws_stack from localstack.utils.common import is_number @@ -75,22 +72,6 @@ def put_log_events( return call_moto(context) -class LogsAwsApiListener(AwsApiListener): - def __init__(self): - self.provider = self._create_provider() - super().__init__("logs", MotoFallbackDispatcher(self.provider)) - - def _create_provider(self): - return LogsProvider() - - def request(self, request: Request) -> Response: - response = super().request(request) - # Fix Incorrect response content-type header from cloudwatch logs #1343. - # True for all logs api responses. - response.headers["content-type"] = APPLICATION_AMZ_JSON_1_1 - return response - - def get_pattern_matcher(pattern: str) -> Callable[[str, Dict], bool]: """Returns a pattern matcher. Can be patched by plugins to return a more sophisticated pattern matcher.""" return lambda _pattern, _log_event: True diff --git a/localstack/services/providers.py b/localstack/services/providers.py index 2a1b14719a5ef..8002ac8dd042a 100644 --- a/localstack/services/providers.py +++ b/localstack/services/providers.py @@ -211,9 +211,10 @@ def awslambda_asf(): @aws_provider() def logs(): - from localstack.services.logs.provider import LogsAwsApiListener + from localstack.services.logs.provider import LogsProvider - listener = LogsAwsApiListener() + provider = LogsProvider() + listener = AwsApiListener("logs", MotoFallbackDispatcher(provider)) return Service("logs", listener=listener) From 06716c55e3e00a93f8b68f6dbf289fbcf8790163 Mon Sep 17 00:00:00 2001 From: Alexander Rashed Date: Tue, 23 Aug 2022 14:06:56 +0200 Subject: [PATCH 08/16] add tests for content-negotiation in ASF serializer --- localstack/aws/protocol/serializer.py | 2 +- tests/unit/aws/protocol/test_serializer.py | 130 ++++++++++++++++++++- 2 files changed, 129 insertions(+), 3 deletions(-) diff --git a/localstack/aws/protocol/serializer.py b/localstack/aws/protocol/serializer.py index 33974576e8b9f..dce427220796e 100644 --- a/localstack/aws/protocol/serializer.py +++ b/localstack/aws/protocol/serializer.py @@ -439,7 +439,7 @@ def _create_default_response( def _get_mime_type(self, headers: Optional[Dict | Headers]) -> str: """ Extracts the accepted mime type from the request headers and returns a matching, supported mime type for the - serializer. + serializer or the default mime type of the service if there is no match. :param headers: to extract the "Accept" header from :return: preferred mime type to be used by the serializer (if it is not accepted by the client, an error is logged) diff --git a/tests/unit/aws/protocol/test_serializer.py b/tests/unit/aws/protocol/test_serializer.py index 93a13ce634b59..8278868f6baac 100644 --- a/tests/unit/aws/protocol/test_serializer.py +++ b/tests/unit/aws/protocol/test_serializer.py @@ -7,6 +7,7 @@ from typing import Any, Dict, Iterator, List, Optional from xml.etree import ElementTree +import cbor2 import pytest from botocore.awsrequest import HeadersDict from botocore.endpoint import convert_to_response_dict @@ -14,6 +15,7 @@ from dateutil.tz import tzlocal, tzutc from requests.models import Response as RequestsResponse from urllib3 import HTTPResponse as UrlLibHttpResponse +from werkzeug.datastructures import Headers from werkzeug.wrappers import ResponseStream from localstack.aws.api import CommonServiceException, ServiceException @@ -22,6 +24,7 @@ CancellationReason, TransactionCanceledException, ) +from localstack.aws.api.kinesis import GetRecordsOutput, Record from localstack.aws.api.lambda_ import ResourceNotFoundException from localstack.aws.api.route53 import NoSuchHostedZone from localstack.aws.api.sns import VerificationException @@ -30,6 +33,7 @@ ReceiptHandleIsInvalid, UnsupportedOperation, ) +from localstack.aws.api.sts import Credentials, GetSessionTokenResponse from localstack.aws.protocol.serializer import ( ProtocolSerializerError, QueryResponseSerializer, @@ -37,6 +41,7 @@ create_serializer, ) from localstack.aws.spec import load_service +from localstack.constants import APPLICATION_AMZ_CBOR_1_1 from localstack.http.response import Response from localstack.utils.common import to_str @@ -112,7 +117,7 @@ def _botocore_error_serializer_integration_test( status_code: int, message: Optional[str], is_sender_fault: bool = False, - **additional_error_fields: Dict[str, Any] + **additional_error_fields: Dict[str, Any], ) -> dict: """ Performs an integration test for the error serialization using botocore as parser. @@ -120,7 +125,7 @@ def _botocore_error_serializer_integration_test( - Load the given service (f.e. "sqs") - Serialize the _error_ response with the appropriate serializer from the AWS Serivce Framework - Parse the serialized error response using the botocore parser - - Checks the the metadata is correct (status code, requestID,...) + - Checks if the the metadata is correct (status code, requestID,...) - Checks if the parsed error response content is correct :param service: to load the correct service specification, serializer, and parser @@ -1684,3 +1689,124 @@ def test_restjson_streaming_payload(payload): "Payload": payload, }, ) + + +@pytest.mark.parametrize( + "service,accept_header,content_type_header,expected_mime_type", + [ + # Test default S3 + ("s3", None, None, "text/xml"), + # Test default STS + ("sts", None, None, "text/xml"), + # Test STS for "any" Accept header + ("sts", "*/*", None, "text/xml"), + # Test STS for "any" Accept header and xml content + ("sts", "*/*", "text/xml", "text/xml"), + # Test STS without Accept and xml content + ("sts", None, "text/xml", "text/xml"), + # Test STS without Accept and JSON content + ("sts", None, "application/json", "application/json"), + # Test STS with JSON Accept and XML content + ("sts", "application/json", "text/xml", "application/json"), + # Test default Kinesis + ("kinesis", None, None, "application/json"), + # Test Kinesis for "any" Accept header + ("kinesis", "*/*", None, "application/json"), + # Test Kinesis for "any" Accept header and JSON content + ("kinesis", "*/*", "application/json", "application/json"), + # Test Kinesis without Accept and CBOR content + ("kinesis", None, "application/cbor", "application/cbor"), + # Test Kinesis without Accept and CBOR content + ("kinesis", None, "application/cbor", "application/cbor"), + # Test Kinesis with JSON Accept and CBOR content + ("kinesis", "application/json", "application/cbor", "application/json"), + # Test Kinesis with CBOR Accept and JSON content + ("kinesis", "application/cbor", "application/json", "application/cbor"), + # Test Kinesis with CBOR 1.1 Accept and JSON content + ("kinesis", APPLICATION_AMZ_CBOR_1_1, "application/json", APPLICATION_AMZ_CBOR_1_1), + # Test Kinesis with non-supported Accept header and without Content-Type + ("kinesis", "unknown/content-type", None, "application/json"), + # Test Kinesis with non-supported Accept header and CBOR Content-Type + ("kinesis", "unknown/content-type", "application/cbor", "application/json"), + # Test Kinesis with non-supported Content-Type + ("kinesis", None, "unknown/content-type", "application/json"), + ], +) +def test_accept_header_detection( + service: str, + accept_header: Optional[str], + content_type_header: Optional[str], + expected_mime_type: str, +): + service_model = load_service(service) + response_serializer = create_serializer(service_model) + headers = Headers() + if accept_header: + headers["Accept"] = accept_header + if content_type_header: + headers["Content-Type"] = content_type_header + mime_type = response_serializer._get_mime_type(headers) + assert ( + mime_type == expected_mime_type + ), f"Detected mime type ({mime_type}) was not as expected ({expected_mime_type})" + + +@pytest.mark.parametrize( + "headers_dict", + [{"Content-Type": "application/json"}, {"Accept": "application/json"}], +) +def test_query_protocol_json_serialization(headers_dict): + service = load_service("sts") + response_serializer = create_serializer(service) + headers = Headers(headers_dict) + utc_timestamp = 1661255665.123 + response_data = GetSessionTokenResponse( + Credentials=Credentials( + AccessKeyId="accessKeyId", + SecretAccessKey="secretAccessKey", + SessionToken="sessionToken", + Expiration=datetime.utcfromtimestamp(utc_timestamp), + ) + ) + result: Response = response_serializer.serialize_to_response( + response_data, service.operation_model("GetSessionToken"), headers + ) + assert result is not None + assert result.content_type is not None + assert result.content_type == "application/json" + parsed_data = json.loads(result.data) + # Ensure the structure is the same as for query-xml (f.e. with "SOAP"-like root element), but just JSON encoded + assert "GetSessionTokenResponse" in parsed_data + assert "ResponseMetadata" in parsed_data["GetSessionTokenResponse"] + assert "GetSessionTokenResult" in parsed_data["GetSessionTokenResponse"] + # Make sure the timestamp is formatted as str(int(utc float)) + assert parsed_data["GetSessionTokenResponse"]["GetSessionTokenResult"].get( + "Credentials", {} + ).get("Expiration") == str(int(utc_timestamp)) + + +@pytest.mark.parametrize( + "headers_dict", + [{"Content-Type": "application/cbor"}, {"Accept": "application/cbor"}], +) +def test_json_protocol_cbor_serialization(headers_dict): + service = load_service("kinesis") + response_serializer = create_serializer(service) + headers = Headers(headers_dict) + response_data = GetRecordsOutput( + Records=[ + Record( + SequenceNumber="test_sequence_number", + Data=b"test_data", + PartitionKey="test_partition_key", + ) + ] + ) + result: Response = response_serializer.serialize_to_response( + response_data, service.operation_model("GetRecords"), headers + ) + assert result is not None + assert result.content_type is not None + assert result.content_type == "application/cbor" + parsed_data = cbor2.loads(result.data) + assert parsed_data == response_data From 25c5e44103eed391f47ec8f838c42e503d5b19c0 Mon Sep 17 00:00:00 2001 From: Alexander Rashed Date: Wed, 24 Aug 2022 11:10:39 +0200 Subject: [PATCH 09/16] remove StepFunctions AwsApiListener extension --- localstack/services/providers.py | 9 ++++++--- localstack/services/stepfunctions/provider.py | 12 +----------- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/localstack/services/providers.py b/localstack/services/providers.py index 8002ac8dd042a..3c468df132359 100644 --- a/localstack/services/providers.py +++ b/localstack/services/providers.py @@ -365,13 +365,16 @@ def events(): @aws_provider() def stepfunctions(): - from localstack.services.stepfunctions.provider import StepFunctionsApiListener + from localstack.services.stepfunctions.provider import StepFunctionsProvider - listener = StepFunctionsApiListener() + provider = StepFunctionsProvider() + listener = AwsApiListener( + "stepfunctions", HttpFallbackDispatcher(provider, provider.get_forward_url) + ) return Service( "stepfunctions", listener=listener, - lifecycle_hook=listener.provider, + lifecycle_hook=provider, ) diff --git a/localstack/services/stepfunctions/provider.py b/localstack/services/stepfunctions/provider.py index 273790cc8bd43..9e4eb8bc4972e 100644 --- a/localstack/services/stepfunctions/provider.py +++ b/localstack/services/stepfunctions/provider.py @@ -9,8 +9,7 @@ LogLevel, StepfunctionsApi, ) -from localstack.aws.forwarder import HttpFallbackDispatcher, get_request_forwarder_http -from localstack.aws.proxy import AwsApiListener +from localstack.aws.forwarder import get_request_forwarder_http from localstack.constants import LOCALHOST from localstack.services.plugins import ServiceLifecycleHook from localstack.services.stepfunctions.stepfunctions_starter import ( @@ -19,15 +18,6 @@ ) -class StepFunctionsApiListener(AwsApiListener): - def __init__(self, provider=None): - provider = provider or StepFunctionsProvider() - self.provider = provider - super().__init__( - "stepfunctions", HttpFallbackDispatcher(provider, provider.get_forward_url) - ) - - class StepFunctionsProvider(StepfunctionsApi, ServiceLifecycleHook): def __init__(self): self.forward_request = get_request_forwarder_http(self.get_forward_url) From 7276d8ddf20647c4ca341922892d7a4346e15f87 Mon Sep 17 00:00:00 2001 From: Alexander Rashed Date: Wed, 24 Aug 2022 14:20:11 +0200 Subject: [PATCH 10/16] fix Api Gateway UpdateDeployment response status code check --- tests/integration/apigateway_fixtures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/apigateway_fixtures.py b/tests/integration/apigateway_fixtures.py index 2c92b7a97c0b8..b307496d4bc32 100644 --- a/tests/integration/apigateway_fixtures.py +++ b/tests/integration/apigateway_fixtures.py @@ -144,7 +144,7 @@ def create_rest_api_deployment(apigateway_client, **kwargs): def update_rest_api_deployment(apigateway_client, **kwargs): response = apigateway_client.update_deployment(**kwargs) - assert_response_is_201(response) + assert_response_is_200(response) return response From 4c9e9b3456958b44d4fb872b72abbee3988e602e Mon Sep 17 00:00:00 2001 From: Alexander Rashed Date: Thu, 25 Aug 2022 16:04:19 +0200 Subject: [PATCH 11/16] remove proxy_moto (obsolete) --- localstack/services/moto.py | 19 ++----------------- tests/integration/test_moto.py | 15 ++++----------- 2 files changed, 6 insertions(+), 28 deletions(-) diff --git a/localstack/services/moto.py b/localstack/services/moto.py index 61e42f8c2faef..668f314727882 100644 --- a/localstack/services/moto.py +++ b/localstack/services/moto.py @@ -3,7 +3,7 @@ """ import sys from functools import lru_cache, partial -from typing import Callable, Optional, Union +from typing import Callable from moto.backends import get_backend as get_moto_backend from moto.core.exceptions import RESTError @@ -58,21 +58,6 @@ def call_moto_with_request( return call_moto(local_context) -def proxy_moto( - context: RequestContext, service_request: ServiceRequest = None -) -> Optional[Union[ServiceResponse]]: - """ - Similar to ``call``, only that ``proxy`` does not parse the HTTP response into a ServiceResponse, but instead - returns directly the HTTP response. This can be useful to pass through moto's response directly to the client. - - :param context: the request context - :param service_request: currently not being used, added to satisfy ServiceRequestHandler contract - :return: the Response from moto or the ServiceResponse dictionary (to be serialized again) in case the Content-Type - of the response does not explicitly match the Accept header of the request - """ - return dispatch_to_backend(context, dispatch_to_moto) - - def MotoFallbackDispatcher(provider: object) -> DispatchTable: """ Wraps a provider with a moto fallthrough mechanism. It does by creating a new DispatchTable from the original @@ -82,7 +67,7 @@ def MotoFallbackDispatcher(provider: object) -> DispatchTable: :param provider: the ASF provider :return: a modified DispatchTable """ - return ForwardingFallbackDispatcher(provider, proxy_moto) + return ForwardingFallbackDispatcher(provider, call_moto) def dispatch_to_moto(context: RequestContext) -> Response: diff --git a/tests/integration/test_moto.py b/tests/integration/test_moto.py index a824f8eafbf21..0f1439e5fc2f8 100644 --- a/tests/integration/test_moto.py +++ b/tests/integration/test_moto.py @@ -53,13 +53,6 @@ def test_call_non_implemented_operation(): ) -def test_proxy_non_implemented_operation(): - with pytest.raises(NotImplementedError): - moto.proxy_moto( - moto.create_aws_request_context("athena", "DeleteDataCatalog", {"Name": "foo"}) - ) - - def test_call_with_sqs_modifies_state_in_moto_backend(): """Whitebox test to check that moto backends are populated correctly""" from moto.sqs.models import sqs_backends @@ -192,9 +185,9 @@ def test_call_multi_region_backends(): del sqs_backends["eu-central-1"].queues[qname_eu] -def test_proxy_with_sqs_invalid_call_raises_exception(): +def test_call_with_sqs_invalid_call_raises_exception(): with pytest.raises(ServiceException): - moto.proxy_moto( + moto.call_moto( moto.create_aws_request_context( "sqs", "DeleteQueue", @@ -205,10 +198,10 @@ def test_proxy_with_sqs_invalid_call_raises_exception(): ) -def test_proxy_with_sqs_returns_service_response(): +def test_call_with_sqs_returns_service_response(): qname = f"queue-{short_uid()}" - create_queue_response = moto.proxy_moto( + create_queue_response = moto.call_moto( moto.create_aws_request_context("sqs", "CreateQueue", {"QueueName": qname}) ) From 251d454f6eb82b11c7b9499388268fbd60e8b019 Mon Sep 17 00:00:00 2001 From: Alexander Rashed Date: Fri, 26 Aug 2022 12:31:50 +0200 Subject: [PATCH 12/16] mark kinesalite as deprecated --- localstack/config.py | 2 +- localstack/services/kinesis/provider.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/localstack/config.py b/localstack/config.py index a2401e7b370a8..3d36d587a894b 100644 --- a/localstack/config.py +++ b/localstack/config.py @@ -554,7 +554,7 @@ def in_docker(): # Delay between data persistence (in seconds) KINESIS_MOCK_PERSIST_INTERVAL = os.environ.get("KINESIS_MOCK_PERSIST_INTERVAL", "").strip() or "5s" -# Kinesis provider - either "kinesis-mock" or "kinesalite" +# Kinesis provider - either "kinesis-mock" or "kinesalite" (deprecated, kinesalite support will be removed) KINESIS_PROVIDER = os.environ.get("KINESIS_PROVIDER") or "kinesis-mock" # Whether or not to handle lambda event sources as synchronous invocations diff --git a/localstack/services/kinesis/provider.py b/localstack/services/kinesis/provider.py index ca89ce0185449..07aa06e4b746e 100644 --- a/localstack/services/kinesis/provider.py +++ b/localstack/services/kinesis/provider.py @@ -211,6 +211,7 @@ def deregister_stream_consumer( consumer_name: ConsumerName = "", consumer_arn: ConsumerARN = "", ) -> None: + # TODO remove this method when deleting kinesalite support if config.KINESIS_PROVIDER == "kinesalite": def consumer_filter(consumer: ConsumerDescription): @@ -237,6 +238,7 @@ def list_stream_consumers( max_results: ListStreamConsumersInputLimit = None, stream_creation_timestamp: Timestamp = None, ) -> ListStreamConsumersOutput: + # TODO remove this method when deleting kinesalite support if config.KINESIS_PROVIDER == "kinesalite": stream_consumers = KinesisBackend.get().stream_consumers consumers: List[Consumer] = [] @@ -260,6 +262,7 @@ def describe_stream_consumer( consumer_name: ConsumerName = None, consumer_arn: ConsumerARN = None, ) -> DescribeStreamConsumerOutput: + # TODO remove this method when deleting kinesalite support if config.KINESIS_PROVIDER == "kinesalite": consumer_to_locate = find_consumer(consumer_arn, consumer_name, stream_arn) if not consumer_to_locate: @@ -274,6 +277,7 @@ def describe_stream_consumer( def enable_enhanced_monitoring( self, context: RequestContext, stream_name: StreamName, shard_level_metrics: MetricsNameList ) -> EnhancedMonitoringOutput: + # TODO remove this method when deleting kinesalite support if config.KINESIS_PROVIDER == "kinesalite": stream_metrics = KinesisBackend.get().enhanced_metrics[stream_name] stream_metrics.update(shard_level_metrics) @@ -290,6 +294,7 @@ def enable_enhanced_monitoring( def disable_enhanced_monitoring( self, context: RequestContext, stream_name: StreamName, shard_level_metrics: MetricsNameList ) -> EnhancedMonitoringOutput: + # TODO remove this method when deleting kinesalite support if config.KINESIS_PROVIDER == "kinesalite": region = KinesisBackend.get() region.enhanced_metrics[stream_name] = region.enhanced_metrics[stream_name] - set( @@ -312,6 +317,7 @@ def update_shard_count( target_shard_count: PositiveIntegerObject, scaling_type: ScalingType, ) -> UpdateShardCountOutput: + # TODO remove this method when deleting kinesalite support if config.KINESIS_PROVIDER == "kinesalite": # Currently, kinesalite - which backs the Kinesis implementation for localstack - does # not support UpdateShardCount: https://github.com/mhart/kinesalite/issues/61 From 7d356f70756737890440803bd611c5e1ae2c4999 Mon Sep 17 00:00:00 2001 From: Alexander Rashed Date: Fri, 26 Aug 2022 13:34:15 +0200 Subject: [PATCH 13/16] use _server managed by kinesis_starter --- localstack/services/kinesis/provider.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/localstack/services/kinesis/provider.py b/localstack/services/kinesis/provider.py index 07aa06e4b746e..1426dd5020c2a 100644 --- a/localstack/services/kinesis/provider.py +++ b/localstack/services/kinesis/provider.py @@ -5,6 +5,7 @@ from random import random from typing import Dict, List, Set +import localstack.services.kinesis.kinesis_starter as starter from localstack import config from localstack.aws.api import RequestContext from localstack.aws.api.kinesis import ( @@ -46,7 +47,6 @@ ) from localstack.constants import LOCALHOST from localstack.services.generic_proxy import RegionBackend -from localstack.services.kinesis.kinesis_starter import check_kinesis, start_kinesis from localstack.services.plugins import ServiceLifecycleHook from localstack.utils.aws import aws_stack @@ -83,16 +83,13 @@ def find_consumer(consumer_arn="", consumer_name="", stream_arn=""): class KinesisProvider(KinesisApi, ServiceLifecycleHook): - def __init__(self): - self._server = None - def on_before_start(self): - self._server = start_kinesis() - check_kinesis() + starter.start_kinesis() + starter.check_kinesis() def get_forward_url(self): """Return the URL of the backend Kinesis server to forward requests to""" - return f"http://{LOCALHOST}:{self._server.port}" + return f"http://{LOCALHOST}:{starter._server.port}" def subscribe_to_shard( self, From 95ed2a5dcce66a963840a94bc204a9fa6e63944e Mon Sep 17 00:00:00 2001 From: Alexander Rashed Date: Mon, 29 Aug 2022 08:54:51 +0200 Subject: [PATCH 14/16] migrate #6732 to new ASF provider --- localstack/services/kinesis/provider.py | 14 ++++++++++---- tests/integration/test_kinesis.py | 4 ++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/localstack/services/kinesis/provider.py b/localstack/services/kinesis/provider.py index 1426dd5020c2a..2a4c0be3d096d 100644 --- a/localstack/services/kinesis/provider.py +++ b/localstack/services/kinesis/provider.py @@ -49,8 +49,10 @@ from localstack.services.generic_proxy import RegionBackend from localstack.services.plugins import ServiceLifecycleHook from localstack.utils.aws import aws_stack +from localstack.utils.time import now_utc LOG = logging.getLogger(__name__) +MAX_SUBSCRIPTION_SECONDS = 300 class KinesisBackend(RegionBackend): @@ -116,8 +118,10 @@ def subscribe_to_shard( def event_generator(): shard_iterator = initial_shard_iterator last_sequence_number = starting_sequence_number - # TODO: find better way to run loop up to max 5 minutes (until connection terminates)! - for i in range(5 * 60): + + maximum_duration_subscription_timestamp = now_utc() + MAX_SUBSCRIPTION_SECONDS + + while now_utc() < maximum_duration_subscription_timestamp: try: result = kinesis.get_records(ShardIterator=shard_iterator) except Exception as e: @@ -131,8 +135,10 @@ def event_generator(): shard_iterator = result.get("NextShardIterator") records = result.get("Records", []) if not records: - time.sleep(1) - continue + # On AWS there is *at least* 1 event every 5 seconds + # but this is not possible in this structure. + # In order to avoid a 5-second blocking call, we make the compromise of 3 seconds. + time.sleep(3) yield SubscribeToShardEventStream( SubscribeToShardEvent=SubscribeToShardEvent( diff --git a/tests/integration/test_kinesis.py b/tests/integration/test_kinesis.py index ad7f47e6a2d68..ea143bd53a796 100644 --- a/tests/integration/test_kinesis.py +++ b/tests/integration/test_kinesis.py @@ -7,7 +7,7 @@ import requests from localstack import config, constants -from localstack.services.kinesis import kinesis_listener +from localstack.services.kinesis import provider as kinesis_provider from localstack.utils.aws import aws_stack from localstack.utils.common import poll_condition, retry, select_attributes, short_uid from localstack.utils.kinesis import kinesis_connector @@ -295,7 +295,7 @@ def test_record_lifecycle_data_integrity( snapshot.match("Records", response_records) @pytest.mark.aws_validated - @patch.object(kinesis_listener, "MAX_SUBSCRIPTION_SECONDS", 3) + @patch.object(kinesis_provider, "MAX_SUBSCRIPTION_SECONDS", 3) def test_subscribe_to_shard_timeout( self, kinesis_client, From 5d328bfcad18f32a66a8bc77395802f6ee47b851 Mon Sep 17 00:00:00 2001 From: Alexander Rashed Date: Mon, 29 Aug 2022 09:20:27 +0200 Subject: [PATCH 15/16] fix moto forwarding --- localstack/services/moto.py | 41 ++++++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/localstack/services/moto.py b/localstack/services/moto.py index 668f314727882..7341a33a6599c 100644 --- a/localstack/services/moto.py +++ b/localstack/services/moto.py @@ -2,8 +2,8 @@ This module provides tools to call moto using moto and botocore internals without going through the moto HTTP server. """ import sys -from functools import lru_cache, partial -from typing import Callable +from functools import lru_cache +from typing import Callable, Optional, Union from moto.backends import get_backend as get_moto_backend from moto.core.exceptions import RESTError @@ -34,6 +34,17 @@ user_agent = f"Localstack/{localstack_version} Python/{sys.version.split(' ')[0]}" +def call_moto(context: RequestContext, include_response_metadata=False) -> ServiceResponse: + """ + Call moto with the given request context and receive a parsed ServiceResponse. + + :param context: the request context + :param include_response_metadata: whether to include botocore's "ResponseMetadata" attribute + :return: a serialized AWS ServiceResponse (same as boto3 would return) + """ + return dispatch_to_backend(context, dispatch_to_moto, include_response_metadata) + + def call_moto_with_request( context: RequestContext, service_request: ServiceRequest ) -> ServiceResponse: @@ -58,6 +69,19 @@ def call_moto_with_request( return call_moto(local_context) +def _proxy_moto( + context: RequestContext, request: ServiceRequest +) -> Optional[Union[ServiceResponse, Response]]: + """ + Wraps `call_moto` such that the interface is compliant with a ServiceRequestHandler. + + :param context: the request context + :param service_request: currently not being used, added to satisfy ServiceRequestHandler contract + :return: the Response from moto + """ + return call_moto(context) + + def MotoFallbackDispatcher(provider: object) -> DispatchTable: """ Wraps a provider with a moto fallthrough mechanism. It does by creating a new DispatchTable from the original @@ -67,7 +91,7 @@ def MotoFallbackDispatcher(provider: object) -> DispatchTable: :param provider: the ASF provider :return: a modified DispatchTable """ - return ForwardingFallbackDispatcher(provider, call_moto) + return ForwardingFallbackDispatcher(provider, _proxy_moto) def dispatch_to_moto(context: RequestContext) -> Response: @@ -144,14 +168,3 @@ def load_moto_routing_table(service: str) -> Map: url_map.add(Rule(url_path, endpoint=endpoint, strict_slashes=strict_slashes)) return url_map - - -call_moto = partial(dispatch_to_backend, http_request_dispatcher=dispatch_to_moto) -""" -Call moto with the given request context and receive a parsed ServiceResponse. - -:param context: the request context -:param include_response_metadata: whether to include botocore's "ResponseMetadata" attribute -:return: an AWS ServiceResponse (same as a service provider would return) -:raises ServiceException: if moto returned an error response -""" From 198b4d4e08675d4620ec0703d7adabb192858dc1 Mon Sep 17 00:00:00 2001 From: Alexander Rashed Date: Mon, 29 Aug 2022 13:41:34 +0200 Subject: [PATCH 16/16] fix Kinesis server access --- localstack/services/kinesis/kinesis_starter.py | 4 ++++ localstack/services/kinesis/provider.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/localstack/services/kinesis/kinesis_starter.py b/localstack/services/kinesis/kinesis_starter.py index b1c8d5813b5a3..bd39ff0fe5e27 100644 --- a/localstack/services/kinesis/kinesis_starter.py +++ b/localstack/services/kinesis/kinesis_starter.py @@ -81,3 +81,7 @@ def is_kinesis_running() -> bool: if _server is None: return False return _server.is_running() + + +def get_server(): + return _server diff --git a/localstack/services/kinesis/provider.py b/localstack/services/kinesis/provider.py index 2a4c0be3d096d..f01037e09ae5f 100644 --- a/localstack/services/kinesis/provider.py +++ b/localstack/services/kinesis/provider.py @@ -91,7 +91,7 @@ def on_before_start(self): def get_forward_url(self): """Return the URL of the backend Kinesis server to forward requests to""" - return f"http://{LOCALHOST}:{starter._server.port}" + return f"http://{LOCALHOST}:{starter.get_server().port}" def subscribe_to_shard( self,