diff --git a/localstack/aws/connect.py b/localstack/aws/connect.py new file mode 100644 index 0000000000000..a17a2e9dda4f0 --- /dev/null +++ b/localstack/aws/connect.py @@ -0,0 +1,243 @@ +""" +LocalStack client stack. + +This module provides the interface to perform cross-service communication between +LocalStack providers. + + from localstack.aws.connect import connect_to + + key_pairs = connect_to('ec2').describe_key_pairs() + buckets = connect_to('s3', region='ap-south-1').list_buckets() +""" +import json +from datetime import datetime, timezone +from functools import cache +from typing import Optional, TypedDict + +from boto3.session import Session +from botocore.awsrequest import AWSPreparedRequest +from botocore.client import BaseClient +from botocore.config import Config + +from localstack import config +from localstack.aws.api import RequestContext +from localstack.constants import ( + INTERNAL_AWS_ACCESS_KEY_ID, + INTERNAL_AWS_SECRET_ACCESS_KEY, + MAX_POOL_CONNECTIONS, +) +from localstack.utils.aws.arns import extract_region_from_arn +from localstack.utils.aws.aws_stack import get_local_service_url +from localstack.utils.aws.request_context import get_region_from_request_context + +# +# Data transfer object +# + +LOCALSTACK_DATA_HEADER = "x-localstack-data" +"""Request header which contains the data transfer object.""" + + +class LocalStackData(TypedDict): + """ + LocalStack Data Transfer Object. + + This is sent with every internal request and contains any additional information + LocalStack might need for the purpose of policy enforcement. It is serialised + into text and sent in the request header. + + The keys approximately correspond to: + https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_condition-keys.html + """ + + current_time: str + """Request datetime in ISO8601 format""" + + source_arn: str + """ARN of resource which is triggering the call""" + + source_service: str + """Service principal where the call originates, eg. `ec2`""" + + target_arn: str + """ARN of the resource being targeted.""" + + +def dump_dto(data: LocalStackData) -> str: + # TODO@viren: Improve minification using custom JSONEncoder that use shortened keys + + # To produce a compact JSON representation of DTO, remove spaces from separators + return json.dumps(data, separators=(",", ":")) + + +def load_dto(data: str) -> LocalStackData: + return json.loads(data) + + +# +# Client +# + + +class ConnectFactory: + """ + Factory to build the AWS client. + """ + + def __init__( + self, + use_ssl: bool = False, + verify: bool = False, + aws_access_key_id: Optional[str] = INTERNAL_AWS_ACCESS_KEY_ID, + aws_secret_access_key: Optional[str] = INTERNAL_AWS_SECRET_ACCESS_KEY, + ): + """ + If either of the access keys are set to None, they are loaded from following + locations: + - AWS environment variables + - Credentials file `~/.aws/credentials` + - Config file `~/.aws/config` + + :param use_ssl: Whether to use SSL + :param verify: Whether to verify SSL certificates + :param aws_access_key_id: Access key to use for the client. + If set to None, loads them from botocore session. See above. + :param aws_secret_access_key: Secret key to use for the client. + If set to None, loads them from botocore session. See above. + :param localstack_data: LocalStack data transfer object + """ + self._use_ssl = use_ssl + self._verify = verify + self._aws_access_key_id = aws_access_key_id + self._aws_secret_access_key = aws_secret_access_key + self._aws_session_token = None + self._session = Session() + self._config = Config(max_pool_connections=MAX_POOL_CONNECTIONS) + + def get_partition_for_region(self, region_name: str) -> str: + """ + Return the AWS partition name for a given region, eg. `aws`, `aws-cn`, etc. + """ + return self._session.get_partition_for_region(region_name) + + def get_session_region_name(self) -> str: + """ + Return AWS region as set in the Boto session. + """ + return self._session.region_name + + def get_region_name(self) -> str: + """ + Return the AWS region name from following sources, in order of availability. + - LocalStack request context + - LocalStack default region + - Boto session + """ + return ( + get_region_from_request_context() + or config.DEFAULT_REGION + or self.get_session_region_name() + ) + + @cache + def get_client( + self, + service_name: str, + region_name: str, + use_ssl: bool, + verify: bool, + endpoint_url: str, + aws_access_key_id: str, + aws_secret_access_key: str, + aws_session_token: str, + config: Config, + ) -> BaseClient: + return self._session.client( + service_name=service_name, + region_name=region_name, + use_ssl=use_ssl, + verify=verify, + endpoint_url=endpoint_url, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + config=config, + ) + + def __call__( + self, + target_service: str, + region_name: str = None, + endpoint_url: str = None, + config: Config = None, + source_arn: str = None, + source_service: str = None, + target_arn: str = None, + ) -> BaseClient: + """ + Build and return the client. + + Presence of any attribute apart from `source_*` or `target_*` argument + indicates that this is a client meant for internal calls. + + :param target_service: Service to build the client for, eg. `s3` + :param region_name: Name of the AWS region to be associated with the client + :param endpoint_url: Full endpoint URL to be used by the client. + Defaults to appropriate LocalStack endpoint. + :param config: Boto config for advanced use. + :param source_arn: ARN of resource which triggers the call. Required for + internal calls. + :param source_service: Service name where call originates. Required for + internal calls. + :param target_arn: ARN of targeted resource. Overrides `region_name`. + Required for internal calls. + """ + localstack_data = LocalStackData() + + if source_arn: + localstack_data["source_arn"] = source_arn + + if source_service: + localstack_data["source_service"] = source_service + + if target_arn: + # Attention: region is overriden here + region_name = extract_region_from_arn(target_arn) + localstack_data["target_arn"] = target_arn + + client = self.get_client( + service_name=target_service, + region_name=region_name or self.get_region_name(), + use_ssl=self._use_ssl, + verify=self._verify, + endpoint_url=endpoint_url or get_local_service_url(target_service), + aws_access_key_id=self._aws_access_key_id, + aws_secret_access_key=self._aws_secret_access_key, + aws_session_token=self._aws_session_token, + config=config or self._config, + ) + + def _handler(request: AWSPreparedRequest, **_): + data = localstack_data | LocalStackData( + current_time=datetime.now(timezone.utc).isoformat() + ) + request.headers[LOCALSTACK_DATA_HEADER] = dump_dto(data) + + if len(localstack_data): + client.meta.events.register("before-send.*.*", handler=_handler) + + return client + + +connect_to = ConnectFactory() + +# +# Utilities +# + + +def is_internal_call(context: RequestContext) -> bool: + """ + Whether a given request is an internal LocalStack cross-service call. + """ + return LOCALSTACK_DATA_HEADER in context.request.headers diff --git a/localstack/aws/handlers/auth.py b/localstack/aws/handlers/auth.py index 47e36bef89198..7fe099d1ebe25 100644 --- a/localstack/aws/handlers/auth.py +++ b/localstack/aws/handlers/auth.py @@ -5,7 +5,11 @@ set_aws_access_key_id, set_aws_account_id, ) -from localstack.constants import TEST_AWS_ACCESS_KEY_ID +from localstack.constants import ( + INTERNAL_AWS_ACCESS_KEY_ID, + INTERNAL_AWS_ACCOUNT_ID, + TEST_AWS_ACCESS_KEY_ID, +) from localstack.http import Response from localstack.utils.aws.aws_stack import extract_access_key_id_from_auth_header @@ -45,7 +49,17 @@ def __call__(self, chain: HandlerChain, context: RequestContext, response: Respo set_aws_access_key_id(access_key_id) # Obtain the account ID and save it in the request context - context.account_id = get_account_id_from_access_key_id(access_key_id) + if access_key_id == INTERNAL_AWS_ACCESS_KEY_ID: + # For internal calls, a special account ID is used for request context + # Cross account calls don't have the same auth flows as user-originating calls + # which means there is no true Account ID. + # The invocations of `get_aws_account_id()` used to resolve the stores must not break. + # We don't use the DEFAULT_AWS_ACCOUNT_ID either to help identify bugs. + # If correctly implemented with CrossAccountAttribute and ARNs, the provider + # will work with this internal AWS account ID. + context.account_id = INTERNAL_AWS_ACCOUNT_ID + else: + context.account_id = get_account_id_from_access_key_id(access_key_id) # Save the same account ID in the thread context set_aws_account_id(context.account_id) diff --git a/localstack/constants.py b/localstack/constants.py index 31b7288fd0350..9d20d748ecc5b 100644 --- a/localstack/constants.py +++ b/localstack/constants.py @@ -45,6 +45,9 @@ # Fallback Account ID if not available in the client request DEFAULT_AWS_ACCOUNT_ID = "000000000000" +# Fallback Account ID for internal calls +INTERNAL_AWS_ACCOUNT_ID = "000012210000" + # AWS user account ID used for tests - TODO move to config.py if "TEST_AWS_ACCOUNT_ID" not in os.environ: os.environ["TEST_AWS_ACCOUNT_ID"] = DEFAULT_AWS_ACCOUNT_ID diff --git a/localstack/services/awslambda/event_source_listeners/sqs_event_source_listener.py b/localstack/services/awslambda/event_source_listeners/sqs_event_source_listener.py index 45d71a2bd0446..eb4c7c07882a8 100644 --- a/localstack/services/awslambda/event_source_listeners/sqs_event_source_listener.py +++ b/localstack/services/awslambda/event_source_listeners/sqs_event_source_listener.py @@ -63,6 +63,16 @@ def _listener_loop(self, *args): for source in sources: queue_arn = source["EventSourceArn"] region_name = extract_region_from_arn(queue_arn) + ## sqs_client = aws_stack.connect_to_service("sqs", region_name=region_name) + # `sqs` is the factory + # config args include stuff going into boto like retries, max conn pool,. + # could be called `boto_config` + # all options will map to boto_config args + # aws.sqs.configure().credentials(aws_access_key_id=...) + # client can be created from ARNs + # can communicate with Queue ARN + # all methods could be prefixed with `set_` eg. set_target_arn() + # sqs_client = aws_client().target_arn(queue_arn).credentials(sts_call_credentials).sqs() sqs_client = aws_stack.connect_to_service("sqs", region_name=region_name) batch_size = max(min(source.get("BatchSize", 1), 10), 1) diff --git a/localstack/services/s3/notifications.py b/localstack/services/s3/notifications.py index 56ce022a6234c..b76ca9699e6db 100644 --- a/localstack/services/s3/notifications.py +++ b/localstack/services/s3/notifications.py @@ -28,6 +28,7 @@ TopicArn, TopicConfiguration, ) +from localstack.aws.connect import connect_to from localstack.config import DEFAULT_REGION from localstack.services.s3.models import get_moto_s3_backend from localstack.services.s3.utils import ( @@ -350,7 +351,7 @@ def _get_arn_value_and_name(topic_configuration: TopicConfiguration) -> [TopicAr return topic_configuration.get("TopicArn", ""), "TopicArn" def _verify_target_exists(self, arn: str, arn_data: ArnData) -> None: - client = aws_stack.connect_to_service(self.service_name, region_name=arn_data["region"]) + client = connect_to(self.service_name, target_arn=arn, source_service="s3") try: client.get_topic_attributes(TopicArn=arn) except ClientError: diff --git a/localstack/services/s3/s3_listener.py b/localstack/services/s3/s3_listener.py index c7527f60960fa..94eb52f14a561 100644 --- a/localstack/services/s3/s3_listener.py +++ b/localstack/services/s3/s3_listener.py @@ -21,6 +21,7 @@ from localstack import config, constants from localstack.aws.api import CommonServiceException +from localstack.aws.connect import connect_to from localstack.config import get_protocol as get_service_protocol from localstack.services.generic_proxy import ProxyListener from localstack.services.generic_proxy import append_cors_headers as _append_default_cors_headers @@ -368,10 +369,11 @@ def send_notification_for_subscriber( ) if notification.get("Topic"): region = arns.extract_region_from_arn(notification["Topic"]) - sns_client = aws_stack.connect_to_service("sns", region_name=region) + topic_arn = notification["Topic"] + sns_client = connect_to("sns", target_arn=topic_arn, source_service="s3") try: sns_client.publish( - TopicArn=notification["Topic"], + TopicArn=topic_arn, Message=message, Subject="Amazon S3 Notification", ) diff --git a/localstack/services/ses/provider.py b/localstack/services/ses/provider.py index a0af7f6ba805d..e8f94933b45b4 100644 --- a/localstack/services/ses/provider.py +++ b/localstack/services/ses/provider.py @@ -50,12 +50,11 @@ VerificationAttributes, VerificationStatus, ) -from localstack.constants import TEST_AWS_SECRET_ACCESS_KEY +from localstack.aws.connect import connect_to from localstack.services.internal import get_internal_apis from localstack.services.moto import call_moto from localstack.services.plugins import ServiceLifecycleHook from localstack.services.ses.models import SentEmail, SentEmailBody -from localstack.utils.aws import arns, aws_stack from localstack.utils.files import mkdir from localstack.utils.strings import long_uid, to_str from localstack.utils.time import timestamp, timestamp_millis @@ -577,13 +576,4 @@ def emit_delivery_event(self, payload: SNSPayload, sns_topic_arn: str): @staticmethod def _client_for_topic(topic_arn: str) -> "SNSClient": - arn_parameters = arns.parse_arn(topic_arn) - region = arn_parameters["region"] - access_key_id = arn_parameters["account"] - - return aws_stack.connect_to_service( - "sns", - region_name=region, - aws_access_key_id=access_key_id, - aws_secret_access_key=TEST_AWS_SECRET_ACCESS_KEY, - ) + return connect_to("sns", target_arn=topic_arn, source_service="ses") diff --git a/localstack/services/sns/models.py b/localstack/services/sns/models.py index eac0054eda42d..e8eadf2b8e8b8 100644 --- a/localstack/services/sns/models.py +++ b/localstack/services/sns/models.py @@ -7,7 +7,12 @@ subscriptionARN, topicARN, ) -from localstack.services.stores import AccountRegionBundle, BaseStore, LocalAttribute +from localstack.services.stores import ( + AccountRegionBundle, + BaseStore, + CrossAccountAttribute, + LocalAttribute, +) from localstack.utils.strings import long_uid SnsProtocols = Literal[ @@ -89,22 +94,22 @@ class SnsSubscription(TypedDict): class SnsStore(BaseStore): # maps topic ARN to topic's subscriptions - sns_subscriptions: Dict[str, List[SnsSubscription]] = LocalAttribute(default=dict) + SNS_SUBSCRIPTIONS: Dict[topicARN, List[SnsSubscription]] = CrossAccountAttribute(default=dict) # maps subscription ARN to subscription status - subscription_status: Dict[str, Dict] = LocalAttribute(default=dict) + SUBSCRIPTION_STATUS: Dict[topicARN, Dict] = CrossAccountAttribute(default=dict) # maps topic ARN to list of tags - sns_tags: Dict[str, List[Dict]] = LocalAttribute(default=dict) + SNS_TAGS: Dict[topicARN, List[Dict]] = CrossAccountAttribute(default=dict) + + # filter policy are stored as JSON string in subscriptions, store the decoded result Dict + SUBSCRIPTION_FILTER_POLICY: Dict[subscriptionARN, Dict] = CrossAccountAttribute(default=dict) # cache of topic ARN to platform endpoint messages (used primarily for testing) - platform_endpoint_messages: Dict[str, List[Dict]] = LocalAttribute(default=dict) + platform_endpoint_messages: Dict[topicARN, List[Dict]] = LocalAttribute(default=dict) # list of sent SMS messages - TODO: expose via internal API sms_messages: List[Dict] = LocalAttribute(default=list) - # filter policy are stored as JSON string in subscriptions, store the decoded result Dict - subscription_filter_policy: Dict[subscriptionARN, Dict] = LocalAttribute(default=dict) - sns_stores = AccountRegionBundle("sns", SnsStore) diff --git a/localstack/services/sns/provider.py b/localstack/services/sns/provider.py index 1c4445ea3ad79..aea575cee9bb7 100644 --- a/localstack/services/sns/provider.py +++ b/localstack/services/sns/provider.py @@ -1,5 +1,7 @@ import json import logging +import re +from contextlib import contextmanager from typing import Dict, List from botocore.utils import InvalidArnException @@ -74,6 +76,7 @@ topicARN, topicName, ) +from localstack.constants import AUTH_CREDENTIAL_REGEX from localstack.http import Request, Response, Router, route from localstack.services.edge import ROUTER from localstack.services.moto import call_moto @@ -86,7 +89,11 @@ SnsPublishContext, ) from localstack.utils.aws import aws_stack -from localstack.utils.aws.arns import parse_arn +from localstack.utils.aws.arns import ( + extract_account_id_from_arn, + extract_region_from_arn, + parse_arn, +) from localstack.utils.strings import short_uid # set up logger @@ -109,6 +116,38 @@ def on_after_init(self): def get_store() -> SnsStore: return sns_stores[get_aws_account_id()][aws_stack.get_region()] + @staticmethod + @contextmanager + def modify_context_with_arn(context: RequestContext, arn: str): + # TODO@viren this is pretty similar to DynamoDB's `modify_context_request()` + # TODO@viren maybe we can make this `call_moto_with_arn_context` + original_account_id = context.account_id + original_region = context.region + original_auth_header = context.request.headers.get("Authorization") + + arn_region = extract_region_from_arn(arn) + context.region = arn_region + + arn_account_id = extract_account_id_from_arn(arn) + context.account_id = arn_account_id + + context.request.headers["Authorization"] = re.sub( + AUTH_CREDENTIAL_REGEX, + rf"Credential={arn_account_id}/\2/{arn_region}/\4/", + original_auth_header or "", + flags=re.IGNORECASE, + ) + + # TODO@viren:THis header is currently added in handler chain + # consider moving it to `call_moto()` + context.request.headers["x-moto-account-id"] = arn_account_id + + yield context + + context.account_id = original_account_id + context.region = original_region + context.request.headers["Authorization"] = original_auth_header + def add_permission( self, context: RequestContext, @@ -247,7 +286,8 @@ def set_topic_attributes( attribute_name: attributeName, attribute_value: attributeValue = None, ) -> None: - call_moto(context) + with self.modify_context_with_arn(context, topic_arn): + call_moto(context) def verify_sms_sandbox_phone_number( self, context: RequestContext, phone_number: PhoneNumberString, one_time_password: OTPCode @@ -258,7 +298,8 @@ def verify_sms_sandbox_phone_number( def get_topic_attributes( self, context: RequestContext, topic_arn: topicARN ) -> GetTopicAttributesResponse: - moto_response = call_moto(context) + with self.modify_context_with_arn(context, topic_arn): + moto_response = call_moto(context) # todo fix some attributes by moto, see snapshot return GetTopicAttributesResponse(**moto_response) @@ -274,7 +315,7 @@ def publish_batch( ) store = self.get_store() - if topic_arn not in store.sns_subscriptions: + if topic_arn not in store.SNS_SUBSCRIPTIONS: raise NotFoundException( "Topic does not exist", ) @@ -365,7 +406,7 @@ def set_subscription_attributes( if attribute_name == "FilterPolicy": store = self.get_store() - store.subscription_filter_policy[subscription_arn] = json.loads(attribute_value) + store.SUBSCRIPTION_FILTER_POLICY[subscription_arn] = json.loads(attribute_value) sub[attribute_name] = attribute_value @@ -380,11 +421,11 @@ def confirm_subscription( sub_arn = None # TODO: this is false, we validate only one sub and not all for topic # WRITE AWS VALIDATED TEST FOR IT - for k, v in store.subscription_status.items(): + for k, v in store.SUBSCRIPTION_STATUS.items(): if v.get("Token") == token and v["TopicArn"] == topic_arn: v["Status"] = "Subscribed" sub_arn = k - for k, v in store.sns_subscriptions.items(): + for k, v in store.SNS_SUBSCRIPTIONS.items(): for i in v: if i["TopicArn"] == topic_arn: i["PendingConfirmation"] = "false" @@ -397,7 +438,7 @@ def untag_resource( ) -> UntagResourceResponse: call_moto(context) store = self.get_store() - store.sns_tags[resource_arn] = [ + store.SNS_TAGS[resource_arn] = [ t for t in _get_tags(resource_arn) if t["Key"] not in tag_keys ] return UntagResourceResponse() @@ -481,8 +522,8 @@ def should_be_kept(current_subscription: SnsSubscription, target_subscription_ar return False - for topic_arn, existing_subs in store.sns_subscriptions.items(): - store.sns_subscriptions[topic_arn] = [ + for topic_arn, existing_subs in store.SNS_SUBSCRIPTIONS.items(): + store.SNS_SUBSCRIPTIONS[topic_arn] = [ sub for sub in existing_subs if should_be_kept(sub, subscription_arn) ] @@ -582,7 +623,7 @@ def publish( ) else: topic = topic_arn or target_arn - if topic not in store.sns_subscriptions: + if topic not in store.SNS_SUBSCRIPTIONS: raise NotFoundException( "Topic does not exist", ) @@ -661,8 +702,8 @@ def subscribe( subscription_arn = moto_response.get("SubscriptionArn") store = self.get_store() - topic_subs = store.sns_subscriptions[topic_arn] = ( - store.sns_subscriptions.get(topic_arn) or [] + topic_subs = store.SNS_SUBSCRIPTIONS[topic_arn] = ( + store.SNS_SUBSCRIPTIONS.get(topic_arn) or [] ) # An endpoint may only be subscribed to a topic once. Subsequent # subscribe calls do nothing (subscribe is idempotent). @@ -685,17 +726,17 @@ def subscribe( if attributes: subscription.update(attributes) if "FilterPolicy" in attributes: - store.subscription_filter_policy[subscription_arn] = json.loads( + store.SUBSCRIPTION_FILTER_POLICY[subscription_arn] = json.loads( attributes["FilterPolicy"] ) topic_subs.append(subscription) - if subscription_arn not in store.subscription_status: - store.subscription_status[subscription_arn] = {} + if subscription_arn not in store.SUBSCRIPTION_STATUS: + store.SUBSCRIPTION_STATUS[subscription_arn] = {} subscription_token = short_uid() - store.subscription_status[subscription_arn].update( + store.SUBSCRIPTION_STATUS[subscription_arn].update( {"TopicArn": topic_arn, "Token": subscription_token, "Status": "Not Subscribed"} ) # Send out confirmation message for HTTP(S), fix for https://github.com/localstack/localstack/issues/881 @@ -730,7 +771,7 @@ def tag_resource( call_moto(context) store = self.get_store() - existing_tags = store.sns_tags.get(resource_arn, []) + existing_tags = store.SNS_TAGS.get(resource_arn, []) def existing_tag_index(item): for idx, tag in enumerate(existing_tags): @@ -745,14 +786,15 @@ def existing_tag_index(item): else: existing_tags[existing_index] = item - store.sns_tags[resource_arn] = existing_tags + store.SNS_TAGS[resource_arn] = existing_tags return TagResourceResponse() def delete_topic(self, context: RequestContext, topic_arn: topicARN) -> None: - call_moto(context) + with self.modify_context_with_arn(context, topic_arn): + call_moto(context) store = self.get_store() - store.sns_subscriptions.pop(topic_arn, None) - store.sns_tags.pop(topic_arn, None) + store.SNS_SUBSCRIPTIONS.pop(topic_arn, None) + store.SNS_TAGS.pop(topic_arn, None) def create_topic( self, @@ -772,7 +814,7 @@ def create_topic( ) if tags: self.tag_resource(context=context, resource_arn=topic_arn, tags=tags) - store.sns_subscriptions[topic_arn] = store.sns_subscriptions.get(topic_arn) or [] + store.SNS_SUBSCRIPTIONS[topic_arn] = store.SNS_SUBSCRIPTIONS.get(topic_arn) or [] return CreateTopicResponse(TopicArn=topic_arn) @@ -780,7 +822,7 @@ def get_subscription_by_arn(sub_arn): store = SnsProvider.get_store() # TODO maintain separate map instead of traversing all items # how to deprecate the store without breaking pods/persistence - for key, subscriptions in store.sns_subscriptions.items(): + for key, subscriptions in store.SNS_SUBSCRIPTIONS.items(): for sub in subscriptions: if sub["SubscriptionArn"] == sub_arn: return sub @@ -788,10 +830,10 @@ def get_subscription_by_arn(sub_arn): def _get_tags(topic_arn): store = SnsProvider.get_store() - if topic_arn not in store.sns_tags: - store.sns_tags[topic_arn] = [] + if topic_arn not in store.SNS_TAGS: + store.SNS_TAGS[topic_arn] = [] - return store.sns_tags[topic_arn] + return store.SNS_TAGS[topic_arn] def is_raw_message_delivery(susbcriber): @@ -920,8 +962,8 @@ def validate_message_attribute_name(name: str) -> None: def extract_tags(topic_arn, tags, is_create_topic_request, store): - existing_tags = list(store.sns_tags.get(topic_arn, [])) - existing_sub = store.sns_subscriptions.get(topic_arn, None) + existing_tags = list(store.SNS_TAGS.get(topic_arn, [])) + existing_sub = store.SNS_SUBSCRIPTIONS.get(topic_arn, None) # if this is none there is nothing to check if existing_sub is not None: if tags is None: diff --git a/localstack/services/sns/publisher.py b/localstack/services/sns/publisher.py index eaf970ceed5ba..e8627345a3f0b 100644 --- a/localstack/services/sns/publisher.py +++ b/localstack/services/sns/publisher.py @@ -1055,7 +1055,7 @@ def _should_publish( Validate that the message should be relayed to the subscriber, depending on the filter policy """ subscriber_arn = subscriber["SubscriptionArn"] - filter_policy = store.subscription_filter_policy.get(subscriber_arn) + filter_policy = store.SUBSCRIPTION_FILTER_POLICY.get(subscriber_arn) if not filter_policy: return True # default value is `MessageAttributes` @@ -1071,7 +1071,7 @@ def _should_publish( ) def publish_to_topic(self, ctx: SnsPublishContext, topic_arn: str) -> None: - subscriptions = ctx.store.sns_subscriptions.get(topic_arn, []) + subscriptions = ctx.store.SNS_SUBSCRIPTIONS.get(topic_arn, []) for subscriber in subscriptions: if self._should_publish(ctx.store, ctx.message, subscriber): notifier = self.topic_notifiers[subscriber["Protocol"]] @@ -1086,7 +1086,7 @@ def publish_to_topic(self, ctx: SnsPublishContext, topic_arn: str) -> None: self.executor.submit(notifier.publish, context=ctx, subscriber=subscriber) def publish_batch_to_topic(self, ctx: SnsBatchPublishContext, topic_arn: str) -> None: - subscriptions = ctx.store.sns_subscriptions.get(topic_arn, []) + subscriptions = ctx.store.SNS_SUBSCRIPTIONS.get(topic_arn, []) for subscriber in subscriptions: protocol = subscriber["Protocol"] notifier = self.batch_topic_notifiers.get(protocol) @@ -1174,7 +1174,7 @@ def publish_to_topic_subscriber( :param subscription_arn: the ARN of the subscriber :return: None """ - subscriptions: List[SnsSubscription] = ctx.store.sns_subscriptions.get(topic_arn, []) + subscriptions: List[SnsSubscription] = ctx.store.SNS_SUBSCRIPTIONS.get(topic_arn, []) for subscriber in subscriptions: if subscriber["SubscriptionArn"] == subscription_arn: notifier = self.topic_notifiers[subscriber["Protocol"]] diff --git a/localstack/testing/pytest/fixtures.py b/localstack/testing/pytest/fixtures.py index 47d5d98a8591d..af20f6592df23 100644 --- a/localstack/testing/pytest/fixtures.py +++ b/localstack/testing/pytest/fixtures.py @@ -135,6 +135,14 @@ def _resource(service): return aws_stack.connect_to_resource_external(service, config=config) +@pytest.fixture +def client_factory(): + def _client_factory(service: str, aws_access_key_id: str, region_name: str = "eu-central-1"): + return _client(service, region_name=region_name, aws_access_key_id=aws_access_key_id) + + yield _client_factory + + @pytest.fixture(scope="class") def create_boto_client(): return _client diff --git a/localstack/utils/aws/dead_letter_queue.py b/localstack/utils/aws/dead_letter_queue.py index a0afa9162b433..a3376cfe7bafa 100644 --- a/localstack/utils/aws/dead_letter_queue.py +++ b/localstack/utils/aws/dead_letter_queue.py @@ -3,6 +3,7 @@ import uuid from typing import Dict, List +from localstack.aws.connect import connect_to from localstack.utils.aws import arns, aws_stack from localstack.utils.aws.aws_models import LambdaFunction from localstack.utils.strings import convert_to_printable_chars, first_char_to_upper @@ -52,7 +53,9 @@ def _send_to_dead_letter_queue(source_arn: str, dlq_arn: str, event: Dict, error LOG.info(msg) raise Exception(msg) elif ":sns:" in dlq_arn: - sns_client = aws_stack.connect_to_service("sns") + sns_client = connect_to( + "sns", target_arn=dlq_arn, source_service="lambda", source_arn=source_arn + ) for message in messages: sns_client.publish( TopicArn=dlq_arn, diff --git a/localstack/utils/aws/message_forwarding.py b/localstack/utils/aws/message_forwarding.py index ab9f7f809db41..0f38096e63a74 100644 --- a/localstack/utils/aws/message_forwarding.py +++ b/localstack/utils/aws/message_forwarding.py @@ -7,6 +7,7 @@ from moto.events.models import events_backends +from localstack.aws.connect import connect_to from localstack.services.apigateway.helpers import extract_query_string_params from localstack.utils import collections from localstack.utils.aws.arns import ( @@ -35,6 +36,7 @@ def send_event_to_target( asynchronous: bool = True, target: Dict = None, ): + # TODO@viren Refactor to accept source ARN and source service, and send them with all `connect_to` calls region = extract_region_from_arn(target_arn) if target is None: target = {} @@ -48,7 +50,7 @@ def send_event_to_target( ) elif ":sns:" in target_arn: - sns_client = connect_to_service("sns", region_name=region) + sns_client = connect_to("sns", target_arn=target_arn) sns_client.publish(TopicArn=target_arn, Message=json.dumps(event)) elif ":sqs:" in target_arn: diff --git a/tests/integration/test_multi_accounts.py b/tests/integration/test_multi_accounts.py index cf1634d50cb07..0437e39fbcf38 100644 --- a/tests/integration/test_multi_accounts.py +++ b/tests/integration/test_multi_accounts.py @@ -1,17 +1,6 @@ -import pytest - -from localstack.testing.pytest.fixtures import _client from localstack.utils.strings import short_uid -@pytest.fixture -def client_factory(): - def _client_factory(service: str, aws_access_key_id: str, region_name: str = "eu-central-1"): - return _client(service, region_name=region_name, aws_access_key_id=aws_access_key_id) - - yield _client_factory - - class TestMultiAccounts: def test_account_id_namespacing_for_moto_backends(self, client_factory): # diff --git a/tests/integration/test_sns.py b/tests/integration/test_sns.py index b80d6a34c04f2..650c1b3636a01 100644 --- a/tests/integration/test_sns.py +++ b/tests/integration/test_sns.py @@ -594,7 +594,7 @@ def test_topic_subscription(self, sns_client, sns_create_topic, sns_subscription def check_subscription(): subscription_arn = subscription["SubscriptionArn"] - subscription_obj = sns_backend.subscription_status[subscription_arn] + subscription_obj = sns_backend.SUBSCRIPTION_STATUS[subscription_arn] assert subscription_obj["Status"] == "Not Subscribed" _token = subscription_obj["Token"] @@ -2911,6 +2911,21 @@ def test_message_structure_json_exc(self, sns_client, sns_create_topic, snapshot ) snapshot.match("duplicate-json-keys", resp) + def test_sns_cross_account_access(self, client_factory): + # Ensure a topic can be retrieved from any region and account ID + sns_client1 = client_factory( + "sns", aws_access_key_id="424242424242", region_name="eu-central-1" + ) + sns_client2 = client_factory( + "sns", aws_access_key_id="100010001000", region_name="ap-south-1" + ) + + topic_name = f"topic-{short_uid()}" + + topic_arn = sns_client1.create_topic(Name=topic_name)["TopicArn"] + + assert sns_client2.get_topic_attributes(TopicArn=topic_arn) + @pytest.mark.aws_validated @pytest.mark.skip_snapshot_verify(paths=["$..Attributes.SubscriptionPrincipal"]) def test_set_subscription_filter_policy_scope(