From add79288ba47407ec9a916a43358b802f9bc7def Mon Sep 17 00:00:00 2001 From: Daniel Fangl Date: Thu, 24 Nov 2022 16:27:03 +0100 Subject: [PATCH 01/18] add mob-programming based client prototype --- localstack/aws/client.py | 114 +++++++++++++++++- .../sqs_event_source_listener.py | 10 ++ 2 files changed, 123 insertions(+), 1 deletion(-) diff --git a/localstack/aws/client.py b/localstack/aws/client.py index a4431f0bcd1ff..cf7e24ab5b653 100644 --- a/localstack/aws/client.py +++ b/localstack/aws/client.py @@ -1,17 +1,27 @@ """Utils to process AWS requests as a client.""" +import dataclasses import io +import json import logging from datetime import datetime -from typing import Dict, Iterable, Optional +from typing import TYPE_CHECKING, Dict, Iterable, Optional +from boto3 import Session +from botocore.awsrequest import AWSPreparedRequest +from botocore.client import BaseClient +from botocore.config import Config from botocore.model import OperationModel from botocore.parsers import ResponseParser, ResponseParserFactory from werkzeug import Response from localstack.aws.api import CommonServiceException, ServiceException, ServiceResponse from localstack.runtime import hooks +from localstack.utils.aws.aws_stack import extract_region_from_arn from localstack.utils.patch import patch +if TYPE_CHECKING: + from mypy_boto3_sqs import SQSClient + LOG = logging.getLogger(__name__) @@ -220,3 +230,105 @@ def raise_service_exception(response: Response, parsed_response: Dict) -> None: """ if service_exception := parse_service_exception(response, parsed_response): raise service_exception + + +@dataclasses.dataclass(frozen=True) +class ClientOptions: + aws_region: Optional[str] = None + endpoint_url: Optional[str] = None + verify_ssl: Optional[bool] = True + use_ssl: Optional[bool] = True + aws_access_key_id: Optional[str] = None + aws_secret_access_key: Optional[str] = None + aws_session_token: Optional[str] = None + boto_config: Optional[Config] = dataclasses.field(default_factory=Config) + # TODO typing + localstack_data: dict[str, str] = dataclasses.field(default_factory=dict) + + +class ClientFactory: + """ + Client Factory to build boto clients for AWS services. + + Usage: + + """ + + # TODO migrate to immutable clientfactory instances + client_options: ClientOptions + session: Session + + def __init__(self, client_options=None): + self.client_options = client_options or ClientOptions() + self.session = Session() + + def endpoint(self, endpoint: str) -> "ClientFactory": + return ClientFactory( + client_options=dataclasses.replace(self.client_options, endpoint_url=endpoint) + ) + + def source_arn(self, arn: str) -> "ClientFactory": + return ClientFactory( + client_options=dataclasses.replace( + self.client_options, + localstack_data=self.client_options.localstack_data | {"source_arn": arn}, + ) + ) + + def target_arn(self, arn: str) -> "ClientFactory": + region = extract_region_from_arn(arn) + return ClientFactory( + client_options=dataclasses.replace(self.client_options, aws_region=region) + ) + + def source_service_principal(self, source_service: str) -> "ClientFactory": + return ClientFactory( + client_options=dataclasses.replace( + self.client_options, + localstack_data=self.client_options.localstack_data + | {"source_service": f"{source_service}.amazonaws.com"}, + ) + ) + + def credentials(self, credentials: Dict[str, str]) -> "ClientFactory": + return ClientFactory(client_options=dataclasses.replace(self.client_options, **credentials)) + + def default_credentials(self) -> "ClientFactory": + return self.credentials( + {"aws_access_key_id": "some-access-key-id", "aws_secret_access_key": "some-secret-key"} + ) + + def environment_credentials(self) -> "ClientFactory": + # TODO wrong output format of session.get_credentials() + return self.credentials(self.session.get_credentials()) + + def boto_config(self, config: Config) -> "ClientFactory": + return ClientFactory( + client_options=dataclasses.replace( + self.client_options, boto_config=self.client_options.boto_config.merge(config) + ) + ) + + def build(self, service: str) -> BaseClient: + assert self.client_options.aws_access_key_id + assert self.client_options.aws_secret_access_key + # TODO: performance :( + client = self.session.client( + service_name=service, + config=self.client_options.boto_config, + aws_access_key_id=self.client_options.aws_access_key_id, + ) + + def event_handler(request: AWSPreparedRequest, **_): + request.headers["x-localstack-data"] = json.dumps(self.client_options.localstack_data) + + client.meta.events.register("before-send.*.*", handler=event_handler) + + return client + + def sqs(self) -> "SQSClient": + return self.build("sqs") + + +def aws_client(): + return ClientFactory() diff --git a/localstack/services/awslambda/event_source_listeners/sqs_event_source_listener.py b/localstack/services/awslambda/event_source_listeners/sqs_event_source_listener.py index 7ee7a9b9ad929..0a5cf3cdfd2ef 100644 --- a/localstack/services/awslambda/event_source_listeners/sqs_event_source_listener.py +++ b/localstack/services/awslambda/event_source_listeners/sqs_event_source_listener.py @@ -63,6 +63,16 @@ def _listener_loop(self, *args): for source in sources: queue_arn = source["EventSourceArn"] region_name = extract_region_from_arn(queue_arn) + ## sqs_client = aws_stack.connect_to_service("sqs", region_name=region_name) + # `sqs` is the factory + # config args include stuff going into boto like retries, max conn pool,. + # could be called `boto_config` + # all options will map to boto_config args + # aws.sqs.configure().credentials(aws_access_key_id=...) + # client can be created from ARNs + # can communicate with Queue ARN + # all methods could be prefixed with `set_` eg. set_target_arn() + # sqs_client = aws_client().target_arn(queue_arn).credentials(sts_call_credentials).sqs() sqs_client = aws_stack.connect_to_service("sqs", region_name=region_name) batch_size = max(min(source.get("BatchSize", 1), 10), 1) From f57429dff054efb1ceb359dce5bd02f2dfbaa746 Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Mon, 2 Jan 2023 19:58:54 +0530 Subject: [PATCH 02/18] WIP --- localstack/aws/client.py | 109 ++++++++++++++++++++++++++++----------- 1 file changed, 80 insertions(+), 29 deletions(-) diff --git a/localstack/aws/client.py b/localstack/aws/client.py index cf7e24ab5b653..c1fbcd016132d 100644 --- a/localstack/aws/client.py +++ b/localstack/aws/client.py @@ -4,17 +4,18 @@ import json import logging from datetime import datetime -from typing import TYPE_CHECKING, Dict, Iterable, Optional +from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional from boto3 import Session from botocore.awsrequest import AWSPreparedRequest from botocore.client import BaseClient -from botocore.config import Config +from botocore.config import Config as BotoConfig from botocore.model import OperationModel from botocore.parsers import ResponseParser, ResponseParserFactory from werkzeug import Response from localstack.aws.api import CommonServiceException, ServiceException, ServiceResponse +from localstack.constants import INTERNAL_AWS_ACCESS_KEY_ID, INTERNAL_AWS_SECRET_ACCESS_KEY from localstack.runtime import hooks from localstack.utils.aws.aws_stack import extract_region_from_arn from localstack.utils.patch import patch @@ -232,77 +233,116 @@ def raise_service_exception(response: Response, parsed_response: Dict) -> None: raise service_exception +# +# Internal AWS client +# + +""" +The internal AWS client API provides the means to perform cross-service communication within LocalStack. +Any additional information LocalStack might need for the purpose of policy enforcement is sent as a +data transfer object. This is a serialised dict object sent in the request header. +""" + +LOCALSTACK_DATA_HEADER = "x-localstack-data" +"""Request header which contains the data transfer object.""" + + +def LocalStackData(TypedDict): + source_arn: str + source_service: str # eg. 'ec2.amazonaws.com' + + +def Credentials(TypedDict): + aws_access_key_id: str + aws_secret_access_key: str + aws_session_token: str + + @dataclasses.dataclass(frozen=True) class ClientOptions: + """This object holds configuration options for the internal AWS client.""" + aws_region: Optional[str] = None - endpoint_url: Optional[str] = None - verify_ssl: Optional[bool] = True - use_ssl: Optional[bool] = True + endpoint_url: Optional[str] = None # TODO@viren should the default endpoint be used here? + verify_ssl: bool = True + use_ssl: bool = True aws_access_key_id: Optional[str] = None aws_secret_access_key: Optional[str] = None aws_session_token: Optional[str] = None - boto_config: Optional[Config] = dataclasses.field(default_factory=Config) - # TODO typing - localstack_data: dict[str, str] = dataclasses.field(default_factory=dict) + boto_config: Optional[BotoConfig] = dataclasses.field(default_factory=BotoConfig) + localstack_data: dict[str, Any] = dataclasses.field(default_factory=LocalStackData) class ClientFactory: - """ - Client Factory to build boto clients for AWS services. - - Usage: - - """ + """Factory to build the internal AWS client.""" # TODO migrate to immutable clientfactory instances client_options: ClientOptions session: Session - def __init__(self, client_options=None): + def __init__(self, client_options: ClientOptions = None): self.client_options = client_options or ClientOptions() self.session = Session() - def endpoint(self, endpoint: str) -> "ClientFactory": + def with_endpoint(self, endpoint: str) -> "ClientFactory": + """Override the API endpoint.""" return ClientFactory( client_options=dataclasses.replace(self.client_options, endpoint_url=endpoint) ) - def source_arn(self, arn: str) -> "ClientFactory": + def with_source_arn(self, arn: str) -> "ClientFactory": + """TODO""" return ClientFactory( client_options=dataclasses.replace( self.client_options, - localstack_data=self.client_options.localstack_data | {"source_arn": arn}, + localstack_data=self.client_options.localstack_data + | LocalStackData(source_arn=arn), ) ) - def target_arn(self, arn: str) -> "ClientFactory": + def with_target_arn(self, arn: str) -> "ClientFactory": + """TODO""" region = extract_region_from_arn(arn) return ClientFactory( client_options=dataclasses.replace(self.client_options, aws_region=region) ) - def source_service_principal(self, source_service: str) -> "ClientFactory": + def with_source_service_principal(self, source_service: str) -> "ClientFactory": + """TODO""" return ClientFactory( client_options=dataclasses.replace( self.client_options, localstack_data=self.client_options.localstack_data - | {"source_service": f"{source_service}.amazonaws.com"}, + | LocalStackData(source_service=f"{source_service}.amazonaws.com"), ) ) - def credentials(self, credentials: Dict[str, str]) -> "ClientFactory": - return ClientFactory(client_options=dataclasses.replace(self.client_options, **credentials)) + def with_credentials( + self, aws_access_key_id: str, aws_secret_access_key: str + ) -> "ClientFactory": + """TODO""" + return ClientFactory( + client_options=dataclasses.replace( + self.client_options, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + ) + ) - def default_credentials(self) -> "ClientFactory": + def with_default_credentials(self) -> "ClientFactory": + """TODO""" return self.credentials( - {"aws_access_key_id": "some-access-key-id", "aws_secret_access_key": "some-secret-key"} + aws_access_key_id=INTERNAL_AWS_ACCESS_KEY_ID, + aws_secret_access_key=INTERNAL_AWS_SECRET_ACCESS_KEY, ) - def environment_credentials(self) -> "ClientFactory": + def with_env_credentials(self) -> "ClientFactory": + """TODO""" # TODO wrong output format of session.get_credentials() return self.credentials(self.session.get_credentials()) - def boto_config(self, config: Config) -> "ClientFactory": + def with_boto_config(self, config: BotoConfig) -> "ClientFactory": + """TODO""" return ClientFactory( client_options=dataclasses.replace( self.client_options, boto_config=self.client_options.boto_config.merge(config) @@ -310,9 +350,13 @@ def boto_config(self, config: Config) -> "ClientFactory": ) def build(self, service: str) -> BaseClient: + """TODO""" assert self.client_options.aws_access_key_id assert self.client_options.aws_secret_access_key - # TODO: performance :( + + # TODO: creating a boto client is very intensive. In old aws_stack, we cache clients based on + # [service_name, client, env, region, endpoint_url, config, internal, kwargs] + # Come up with an appropriate solution here client = self.session.client( service_name=service, config=self.client_options.boto_config, @@ -320,12 +364,19 @@ def build(self, service: str) -> BaseClient: ) def event_handler(request: AWSPreparedRequest, **_): - request.headers["x-localstack-data"] = json.dumps(self.client_options.localstack_data) + # Send a compact JSON representation as DTO + request.headers[LOCALSTACK_DATA_HEADER] = json.dumps( + self.client_options.localstack_data, separators=(",", ":") + ) client.meta.events.register("before-send.*.*", handler=event_handler) return client + # + # Convenience helpers + # + def sqs(self) -> "SQSClient": return self.build("sqs") From 89e43f5d160f79bfd93d63a8f8ff9e44d701d002 Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Wed, 4 Jan 2023 12:08:41 +0530 Subject: [PATCH 03/18] Fix imports --- localstack/aws/client.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/localstack/aws/client.py b/localstack/aws/client.py index 1961294d4ccb9..d3db21a412f9a 100644 --- a/localstack/aws/client.py +++ b/localstack/aws/client.py @@ -17,7 +17,7 @@ from localstack.aws.api import CommonServiceException, ServiceException, ServiceResponse from localstack.constants import INTERNAL_AWS_ACCESS_KEY_ID, INTERNAL_AWS_SECRET_ACCESS_KEY from localstack.runtime import hooks -from localstack.utils.aws.aws_stack import extract_region_from_arn +from localstack.utils.aws.arns import extract_region_from_arn from localstack.utils.patch import patch if TYPE_CHECKING: @@ -238,9 +238,12 @@ def raise_service_exception(response: Response, parsed_response: Dict) -> None: # """ -The internal AWS client API provides the means to perform cross-service communication within LocalStack. -Any additional information LocalStack might need for the purpose of policy enforcement is sent as a -data transfer object. This is a serialised dict object sent in the request header. +The internal AWS client API provides the means to perform cross-service +communication within LocalStack. + +Any additional information LocalStack might need for the purpose of policy +enforcement is sent as a data transfer object. This is a serialised dict object +sent in the request header. """ LOCALSTACK_DATA_HEADER = "x-localstack-data" From 0352fae468e6761db26660f809c5610644434e8e Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Fri, 6 Jan 2023 18:02:08 +0530 Subject: [PATCH 04/18] Updates --- localstack/aws/client.py | 100 ++++++++++++++++++++++++++++++--------- 1 file changed, 78 insertions(+), 22 deletions(-) diff --git a/localstack/aws/client.py b/localstack/aws/client.py index d3db21a412f9a..c8c0a3fb866a4 100644 --- a/localstack/aws/client.py +++ b/localstack/aws/client.py @@ -4,7 +4,7 @@ import json import logging from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional +from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, TypedDict from boto3 import Session from botocore.awsrequest import AWSPreparedRequest @@ -18,6 +18,7 @@ from localstack.constants import INTERNAL_AWS_ACCESS_KEY_ID, INTERNAL_AWS_SECRET_ACCESS_KEY from localstack.runtime import hooks from localstack.utils.aws.arns import extract_region_from_arn +from localstack.utils.aws.aws_stack import get_local_service_url from localstack.utils.patch import patch if TYPE_CHECKING: @@ -250,12 +251,20 @@ def raise_service_exception(response: Response, parsed_response: Dict) -> None: """Request header which contains the data transfer object.""" -def LocalStackData(TypedDict): +class LocalStackData(TypedDict): + """ + LocalStack Data Transfer Object. + """ + source_arn: str source_service: str # eg. 'ec2.amazonaws.com' -def Credentials(TypedDict): +class Credentials(TypedDict): + """ + AWS credentials. + """ + aws_access_key_id: str aws_secret_access_key: str aws_session_token: str @@ -265,19 +274,38 @@ def Credentials(TypedDict): class ClientOptions: """This object holds configuration options for the internal AWS client.""" - aws_region: Optional[str] = None - endpoint_url: Optional[str] = None # TODO@viren should the default endpoint be used here? - verify_ssl: bool = True + region_name: Optional[str] = None + """Name of the AWS region to be associated with the client.""" + + endpoint_url: Optional[str] = None + """Full endpoint URL to be used by the client.""" + use_ssl: bool = True + """Whether or not to use SSL.""" + + verify: bool = True + """Whether or not to verify SSL certificates.""" + aws_access_key_id: Optional[str] = None + """Access key to use for the client.""" + aws_secret_access_key: Optional[str] = None + """Secret key to use for the client.""" + aws_session_token: Optional[str] = None + """Session token to use for the client""" + boto_config: Optional[BotoConfig] = dataclasses.field(default_factory=BotoConfig) + """Boto client configuration for advanced use.""" + localstack_data: dict[str, Any] = dataclasses.field(default_factory=LocalStackData) + """LocalStack data transfer object.""" class ClientFactory: - """Factory to build the internal AWS client.""" + """ + Factory to build the internal AWS client. + """ # TODO migrate to immutable clientfactory instances client_options: ClientOptions @@ -288,13 +316,19 @@ def __init__(self, client_options: ClientOptions = None): self.session = Session() def with_endpoint(self, endpoint: str) -> "ClientFactory": - """Override the API endpoint.""" + """ + Set a custom endpoint. + """ return ClientFactory( client_options=dataclasses.replace(self.client_options, endpoint_url=endpoint) ) def with_source_arn(self, arn: str) -> "ClientFactory": - """TODO""" + """ + Indicate that the client is operating from a given resource. + + This must be used in cross-service requests. + """ return ClientFactory( client_options=dataclasses.replace( self.client_options, @@ -304,14 +338,22 @@ def with_source_arn(self, arn: str) -> "ClientFactory": ) def with_target_arn(self, arn: str) -> "ClientFactory": - """TODO""" - region = extract_region_from_arn(arn) + """ + Create the client to operate on a target resource. + + This must be used in cross-service requests. + """ + region_name = extract_region_from_arn(arn) return ClientFactory( - client_options=dataclasses.replace(self.client_options, aws_region=region) + client_options=dataclasses.replace(self.client_options, region_name=region_name) ) def with_source_service_principal(self, source_service: str) -> "ClientFactory": - """TODO""" + """ + Set the source service principal. + + This must be used in cross-service requests. + """ return ClientFactory( client_options=dataclasses.replace( self.client_options, @@ -323,7 +365,9 @@ def with_source_service_principal(self, source_service: str) -> "ClientFactory": def with_credentials( self, aws_access_key_id: str, aws_secret_access_key: str ) -> "ClientFactory": - """TODO""" + """ + Use custom AWS credentials. + """ return ClientFactory( client_options=dataclasses.replace( self.client_options, @@ -333,19 +377,25 @@ def with_credentials( ) def with_default_credentials(self) -> "ClientFactory": - """TODO""" - return self.credentials( + """ + Use LocalStack default AWS credentials. + """ + return self.with_credentials( aws_access_key_id=INTERNAL_AWS_ACCESS_KEY_ID, aws_secret_access_key=INTERNAL_AWS_SECRET_ACCESS_KEY, ) def with_env_credentials(self) -> "ClientFactory": - """TODO""" + """ + Use AWS credentials from the environment. + """ # TODO wrong output format of session.get_credentials() return self.credentials(self.session.get_credentials()) def with_boto_config(self, config: BotoConfig) -> "ClientFactory": - """TODO""" + """ + Use a custom BotoConfig. + """ return ClientFactory( client_options=dataclasses.replace( self.client_options, boto_config=self.client_options.boto_config.merge(config) @@ -353,17 +403,23 @@ def with_boto_config(self, config: BotoConfig) -> "ClientFactory": ) def build(self, service: str) -> BaseClient: - """TODO""" - assert self.client_options.aws_access_key_id - assert self.client_options.aws_secret_access_key + """ + Finalise the client. + """ + assert self.client_options.aws_access_key_id, "Access key ID is not set" + assert self.client_options.aws_secret_access_key, "Secret access key is not set" + + endpoint_url = self.client_options.endpoint_url or get_local_service_url(service) - # TODO: creating a boto client is very intensive. In old aws_stack, we cache clients based on + # TODO@viren: creating a boto client is very intensive. In old aws_stack, we cache clients based on # [service_name, client, env, region, endpoint_url, config, internal, kwargs] # Come up with an appropriate solution here client = self.session.client( service_name=service, config=self.client_options.boto_config, aws_access_key_id=self.client_options.aws_access_key_id, + aws_secret_access_key=self.client_options.aws_secret_access_key, + endpoint_url=endpoint_url, ) def event_handler(request: AWSPreparedRequest, **_): From 5a0ba6ce8750886d568cf36d8f5110aac36c376e Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Fri, 6 Jan 2023 20:04:01 +0530 Subject: [PATCH 05/18] Add core descriptor --- localstack/services/stores.py | 81 +++++++++++++++++++++++++-- localstack/testing/pytest/fixtures.py | 2 + 2 files changed, 78 insertions(+), 5 deletions(-) diff --git a/localstack/services/stores.py b/localstack/services/stores.py index 1f68724c837a4..1b1ad75b2273f 100644 --- a/localstack/services/stores.py +++ b/localstack/services/stores.py @@ -5,7 +5,7 @@ By convention, Stores are to be defined in `models` submodule of the service by subclassing BaseStore e.g. `localstack.services.sqs.models.SqsStore` -Also by convention, cross-region attributes are declared in CAPITAL_CASE +Also by convention, cross-region and cross-account attributes are declared in CAPITAL_CASE class SqsStore(BaseStore): queues: dict[str, SqsQueue] = LocalAttribute(default=dict) @@ -114,6 +114,47 @@ def _check_region_store_association(self, obj): ) +class CrossAccountAttribute: + """ + Descriptor protocol for marking a store attributes as shared across all regions and accounts. + + This should be used for resources that are identified by ARNs. + """ + + def __init__(self, default: Union[Callable, int, float, str, bool, None]): + """ + :param default: The default value assigned to the cross-account attribute. + This must be a scalar or a callable. + """ + self.default = default + + def __set_name__(self, owner, name): + self.name = name + + def __get__(self, obj: BaseStoreType, objtype=None) -> Any: + self._check_account_store_association(obj) + + if self.name not in obj._universal.keys(): + if isinstance(self.default, Callable): + obj._universal[self.name] = self.default() + else: + obj._universal[self.name] = self.default + + return obj._universal[self.name] + + def __set__(self, obj: BaseStoreType, value: Any): + self._check_account_store_association(obj) + + obj._universal[self.name] = value + + def _check_account_store_association(self, obj): + if not hasattr(obj, "_universal"): + # Raise if a Store is instantiated outside an AccountRegionBundle + raise AttributeError( + "Could not resolve cross-account attribute because there is no associated AccountRegionBundle" + ) + + # # Base models # @@ -124,6 +165,12 @@ class BaseStore: Base class for defining stores for LocalStack providers. """ + _service_name: str + _account_id: str + _region_name: str + _global: dict + _universal: dict + def __repr__(self): try: repr_templ = "<{name} object for {service_name} at {account_id}/{region_name}>" @@ -154,16 +201,20 @@ def __init__( account_id: str, validate: bool = True, lock: RLock = None, + universal: dict = None, ): self.store = store self.account_id = account_id self.service_name = service_name self.validate = validate self.lock = lock or RLock() + self._universal = universal self.valid_regions = get_valid_regions_for_service(service_name) - # Keeps track of all cross-region attributes + # Keeps track of all cross-region attributes. This dict is maintained at + # a region level (hence in RegionBundle). A ref is passed to every store + # intialised in this region so that backref is possible. self._global = {} def __getitem__(self, region_name) -> BaseStoreType: @@ -178,6 +229,7 @@ def __getitem__(self, region_name) -> BaseStoreType: store_obj = self.store() store_obj._global = self._global + store_obj._universal = self._universal store_obj._service_name = self.service_name store_obj._account_id = self.account_id store_obj._region_name = region_name @@ -186,8 +238,12 @@ def __getitem__(self, region_name) -> BaseStoreType: return super().__getitem__(region_name) - def reset(self): - """Clear all store data.""" + def reset(self, _reset_universal: bool = False): + """ + Clear all store data. + + This only deletes the data held in the stores. All instantiated stores are retained. + """ # For safety, clear data in all referenced store instances, if any for store_inst in self.values(): attrs = list(store_inst.__dict__.keys()) @@ -196,6 +252,9 @@ def reset(self): if attr == "_global": store_inst._global.clear() + if attr == "_universal" and _reset_universal: + store_inst._universal.clear() + # reset the local attributes elif attr.startswith(LOCAL_ATTR_PREFIX): delattr(store_inst, attr) @@ -222,6 +281,11 @@ def __init__(self, service_name: str, store: Type[BaseStoreType], validate: bool self.validate = validate self.lock = RLock() + # Keeps track of all cross-account attributes. This dict is maintained at + # the account level (hence in AccountRegionBundle). A ref is passed to + # every region bundle, which in turn passes it to every store in it. + self._universal = {} + def __getitem__(self, account_id: str) -> RegionBundle[BaseStoreType]: if self.validate and not re.match(r"\d{12}", account_id): raise ValueError(f"'{account_id}' is not a valid AWS account ID") @@ -234,15 +298,22 @@ def __getitem__(self, account_id: str) -> RegionBundle[BaseStoreType]: account_id=account_id, validate=self.validate, lock=self.lock, + universal=self._universal, ) return super().__getitem__(account_id) def reset(self): - """Clear all store data.""" + """ + Clear all store data. + + This only deletes the data held in the stores. All instantiated stores are retained. + """ # For safety, clear all referenced region bundles, if any for region_bundle in self.values(): region_bundle.reset() + self._universal.clear() + with self.lock: self.clear() diff --git a/localstack/testing/pytest/fixtures.py b/localstack/testing/pytest/fixtures.py index 6b2add1e12ec1..0f588631b27b3 100644 --- a/localstack/testing/pytest/fixtures.py +++ b/localstack/testing/pytest/fixtures.py @@ -26,6 +26,7 @@ from localstack.services.stores import ( AccountRegionBundle, BaseStore, + CrossAccountAttribute, CrossRegionAttribute, LocalAttribute, ) @@ -1916,6 +1917,7 @@ def pytest_collection_modifyitems(config: Config, items: list[Item]): @pytest.fixture def sample_stores() -> AccountRegionBundle: class SampleStore(BaseStore): + CROSS_ACCOUNT_ATTR = CrossAccountAttribute(default=list) CROSS_REGION_ATTR = CrossRegionAttribute(default=list) region_specific_attr = LocalAttribute(default=list) From 3f399c6de6ebdb730948e25ab31861de2131971d Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Mon, 9 Jan 2023 12:38:47 +0530 Subject: [PATCH 06/18] Add unit tests --- localstack/services/stores.py | 2 +- tests/unit/test_stores.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/localstack/services/stores.py b/localstack/services/stores.py index 1b1ad75b2273f..25e161f8d261f 100644 --- a/localstack/services/stores.py +++ b/localstack/services/stores.py @@ -311,7 +311,7 @@ def reset(self): """ # For safety, clear all referenced region bundles, if any for region_bundle in self.values(): - region_bundle.reset() + region_bundle.reset(_reset_universal=True) self._universal.clear() diff --git a/tests/unit/test_stores.py b/tests/unit/test_stores.py index 8f5e73c9a5eb3..ee934428493b9 100644 --- a/tests/unit/test_stores.py +++ b/tests/unit/test_stores.py @@ -18,24 +18,45 @@ def test_store_reset(self, sample_stores): store1.region_specific_attr.extend([1, 2, 3]) store1.CROSS_REGION_ATTR.extend(["a", "b", "c"]) + store1.CROSS_ACCOUNT_ATTR.extend([100j, 200j, 300j]) store2.region_specific_attr.extend([4, 5, 6]) + store2.CROSS_ACCOUNT_ATTR.extend([400j]) store3.region_specific_attr.extend([7, 8, 9]) store3.CROSS_REGION_ATTR.extend([0.1, 0.2, 0.3]) + store3.CROSS_ACCOUNT_ATTR.extend([500j]) + + # Ensure all stores are affected by cross-account attributes + assert store1.CROSS_ACCOUNT_ATTR == [100j, 200j, 300j, 400j, 500j] + assert store2.CROSS_ACCOUNT_ATTR == [100j, 200j, 300j, 400j, 500j] + assert store3.CROSS_ACCOUNT_ATTR == [100j, 200j, 300j, 400j, 500j] + + assert store1.CROSS_ACCOUNT_ATTR.pop() == 500j + + assert store2.CROSS_ACCOUNT_ATTR == [100j, 200j, 300j, 400j] + assert store3.CROSS_ACCOUNT_ATTR == [100j, 200j, 300j, 400j] # Ensure other account stores are not affected by RegionBundle reset + # Ensure cross-account attributes are not affected by RegionBundle reset sample_stores[account1].reset() assert store1.region_specific_attr == [] assert store1.CROSS_REGION_ATTR == [] + assert store1.CROSS_ACCOUNT_ATTR == [100j, 200j, 300j, 400j] assert store2.region_specific_attr == [] assert store2.CROSS_REGION_ATTR == [] + assert store2.CROSS_ACCOUNT_ATTR == [100j, 200j, 300j, 400j] assert store3.region_specific_attr == [7, 8, 9] assert store3.CROSS_REGION_ATTR == [0.1, 0.2, 0.3] + assert store3.CROSS_ACCOUNT_ATTR == [100j, 200j, 300j, 400j] + # Ensure AccountRegionBundle reset sample_stores.reset() + assert store1.CROSS_ACCOUNT_ATTR == [] + assert store2.CROSS_ACCOUNT_ATTR == [] assert store3.region_specific_attr == [] assert store3.CROSS_REGION_ATTR == [] + assert store3.CROSS_ACCOUNT_ATTR == [] # Ensure essential properties are retained after reset assert store1._region_name == eu_region @@ -106,6 +127,19 @@ def test_store_namespacing(self, sample_stores): != id(backend1_ap._global) ) + # Ensure cross-account data sharing + backend1_eu.CROSS_ACCOUNT_ATTR.extend([100j, 200j, 300j]) + assert backend1_ap.CROSS_ACCOUNT_ATTR == [100j, 200j, 300j] + assert backend1_eu.CROSS_ACCOUNT_ATTR == [100j, 200j, 300j] + assert backend2_ap.CROSS_ACCOUNT_ATTR == [100j, 200j, 300j] + assert backend2_eu.CROSS_ACCOUNT_ATTR == [100j, 200j, 300j] + assert ( + id(backend1_ap._universal) + == id(backend1_eu._universal) + == id(backend2_ap._universal) + == id(backend2_eu._universal) + ) + def test_valid_regions(self): class SampleStore(BaseStore): pass From e1a1d77220aa559fdd061523888bf1c6eeb44706 Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Mon, 9 Jan 2023 14:58:04 +0530 Subject: [PATCH 07/18] Add owner for stores codebase --- CODEOWNERS | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CODEOWNERS b/CODEOWNERS index a127340666584..e3d1d1e72d2e1 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -30,6 +30,10 @@ # HTTP framework /localstack/http/ @thrau +# Stores +/localstack/services/stores.py @viren-nadkarni +/tests/unit/test_stores.py @viren-nadkarni + # Dockerfile /Dockerfile @alexrashed /Dockerfile.rh @alexrashed From 5f1eeb2cf54863ec9f9cbf2229ea55d0cd1ce09d Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Mon, 9 Jan 2023 18:33:13 +0530 Subject: [PATCH 08/18] Enable cross account access for SNS topics --- localstack/services/sns/models.py | 21 ++++--- localstack/services/sns/provider.py | 86 ++++++++++++++++++---------- localstack/services/sns/publisher.py | 8 +-- tests/integration/test_sns.py | 2 +- 4 files changed, 75 insertions(+), 42 deletions(-) diff --git a/localstack/services/sns/models.py b/localstack/services/sns/models.py index e3bfc20460fe9..803d7691617de 100644 --- a/localstack/services/sns/models.py +++ b/localstack/services/sns/models.py @@ -7,7 +7,12 @@ subscriptionARN, topicARN, ) -from localstack.services.stores import AccountRegionBundle, BaseStore, LocalAttribute +from localstack.services.stores import ( + AccountRegionBundle, + BaseStore, + CrossAccountAttribute, + LocalAttribute, +) from localstack.utils.strings import long_uid SnsProtocols = Literal[ @@ -87,22 +92,22 @@ class SnsSubscription(TypedDict): class SnsStore(BaseStore): # maps topic ARN to topic's subscriptions - sns_subscriptions: Dict[str, List[SnsSubscription]] = LocalAttribute(default=dict) + SNS_SUBSCRIPTIONS: Dict[topicARN, List[SnsSubscription]] = CrossAccountAttribute(default=dict) # maps subscription ARN to subscription status - subscription_status: Dict[str, Dict] = LocalAttribute(default=dict) + SUBSCRIPTION_STATUS: Dict[topicARN, Dict] = CrossAccountAttribute(default=dict) # maps topic ARN to list of tags - sns_tags: Dict[str, List[Dict]] = LocalAttribute(default=dict) + SNS_TAGS: Dict[topicARN, List[Dict]] = CrossAccountAttribute(default=dict) + + # filter policy are stored as JSON string in subscriptions, store the decoded result Dict + SUBSCRIPTION_FILTER_POLICY: Dict[subscriptionARN, Dict] = CrossAccountAttribute(default=dict) # cache of topic ARN to platform endpoint messages (used primarily for testing) - platform_endpoint_messages: Dict[str, List[Dict]] = LocalAttribute(default=dict) + platform_endpoint_messages: Dict[topicARN, List[Dict]] = LocalAttribute(default=dict) # list of sent SMS messages - TODO: expose via internal API sms_messages: List[Dict] = LocalAttribute(default=list) - # filter policy are stored as JSON string in subscriptions, store the decoded result Dict - subscription_filter_policy: Dict[subscriptionARN, Dict] = LocalAttribute(default=dict) - sns_stores = AccountRegionBundle("sns", SnsStore) diff --git a/localstack/services/sns/provider.py b/localstack/services/sns/provider.py index 05b8458302a67..1804dbe85287d 100644 --- a/localstack/services/sns/provider.py +++ b/localstack/services/sns/provider.py @@ -1,5 +1,7 @@ import json import logging +import re +from contextlib import contextmanager from typing import Dict, List from botocore.utils import InvalidArnException @@ -75,6 +77,7 @@ topicARN, topicName, ) +from localstack.constants import AUTH_CREDENTIAL_REGEX from localstack.http import Request, Response, Router, route from localstack.services.edge import ROUTER from localstack.services.moto import call_moto @@ -87,7 +90,7 @@ SnsPublishContext, ) from localstack.utils.aws import aws_stack -from localstack.utils.aws.arns import parse_arn +from localstack.utils.aws.arns import extract_region_from_arn, parse_arn from localstack.utils.strings import short_uid # set up logger @@ -110,6 +113,28 @@ def on_after_init(self): def get_store() -> SnsStore: return sns_stores[get_aws_account_id()][aws_stack.get_region()] + @staticmethod + @contextmanager + def modify_context_region_with_arn_region(context: RequestContext, arn: str): + # TODO@viren this is pretty similar to DynamoDB's `modify_context_request()` + original_region = context.region + original_auth_header = context.request.headers.get("Authorization") + + arn_region = extract_region_from_arn(arn) + context.region = arn_region + + context.request.headers["Authorization"] = re.sub( + AUTH_CREDENTIAL_REGEX, + rf"Credential=\1/\2/{arn_region}/\4/", + original_auth_header or "", + flags=re.IGNORECASE, + ) + + yield context + + context.region = original_region + context.request.headers["Authorization"] = original_auth_header + def add_permission( self, context: RequestContext, @@ -248,7 +273,8 @@ def set_topic_attributes( attribute_name: attributeName, attribute_value: attributeValue = None, ) -> None: - call_moto(context) + with self.modify_context_region_with_arn_region(context, topic_arn): + call_moto(context) def verify_sms_sandbox_phone_number( self, context: RequestContext, phone_number: PhoneNumberString, one_time_password: OTPCode @@ -259,7 +285,8 @@ def verify_sms_sandbox_phone_number( def get_topic_attributes( self, context: RequestContext, topic_arn: topicARN ) -> GetTopicAttributesResponse: - moto_response = call_moto(context) + with self.modify_context_region_with_arn_region(context, topic_arn): + moto_response = call_moto(context) # todo fix some attributes by moto, see snapshot return GetTopicAttributesResponse(**moto_response) @@ -275,7 +302,7 @@ def publish_batch( ) store = self.get_store() - if topic_arn not in store.sns_subscriptions: + if topic_arn not in store.SNS_SUBSCRIPTIONS: raise NotFoundException( "Topic does not exist", ) @@ -362,7 +389,7 @@ def set_subscription_attributes( raise InvalidParameterException( "Invalid parameter: FilterPolicy: failed to parse JSON." ) - store.subscription_filter_policy[subscription_arn] = filter_policy + store.SUBSCRIPTION_FILTER_POLICY[subscription_arn] = filter_policy pass elif attribute_name == "RawMessageDelivery": # TODO: only for SQS and https(s) subs, + firehose @@ -404,11 +431,11 @@ def confirm_subscription( sub_arn = None # TODO: this is false, we validate only one sub and not all for topic # WRITE AWS VALIDATED TEST FOR IT - for k, v in store.subscription_status.items(): + for k, v in store.SUBSCRIPTION_STATUS.items(): if v.get("Token") == token and v["TopicArn"] == topic_arn: v["Status"] = "Subscribed" sub_arn = k - for k, v in store.sns_subscriptions.items(): + for k, v in store.SNS_SUBSCRIPTIONS.items(): for i in v: if i["TopicArn"] == topic_arn: i["PendingConfirmation"] = "false" @@ -420,7 +447,7 @@ def untag_resource( ) -> UntagResourceResponse: call_moto(context) store = self.get_store() - store.sns_tags[resource_arn] = [ + store.SNS_TAGS[resource_arn] = [ t for t in _get_tags(resource_arn) if t["Key"] not in tag_keys ] return UntagResourceResponse() @@ -503,8 +530,8 @@ def should_be_kept(current_subscription: SnsSubscription, target_subscription_ar return False - for topic_arn, existing_subs in store.sns_subscriptions.items(): - store.sns_subscriptions[topic_arn] = [ + for topic_arn, existing_subs in store.SNS_SUBSCRIPTIONS.items(): + store.SNS_SUBSCRIPTIONS[topic_arn] = [ sub for sub in existing_subs if should_be_kept(sub, subscription_arn) ] @@ -600,7 +627,7 @@ def publish( ) else: topic = topic_arn or target_arn - if topic not in store.sns_subscriptions: + if topic not in store.SNS_SUBSCRIPTIONS: raise NotFoundException( "Topic does not exist", ) @@ -674,8 +701,8 @@ def subscribe( subscription_arn = moto_response.get("SubscriptionArn") filter_policy = moto_response.get("FilterPolicy") store = self.get_store() - topic_subs = store.sns_subscriptions[topic_arn] = ( - store.sns_subscriptions.get(topic_arn) or [] + topic_subs = store.SNS_SUBSCRIPTIONS[topic_arn] = ( + store.SNS_SUBSCRIPTIONS.get(topic_arn) or [] ) # An endpoint may only be subscribed to a topic once. Subsequent # subscribe calls do nothing (subscribe is idempotent). @@ -685,7 +712,7 @@ def subscribe( SubscriptionArn=existing_topic_subscription["SubscriptionArn"] ) if filter_policy: - store.subscription_filter_policy[subscription_arn] = json.loads(filter_policy) + store.SUBSCRIPTION_FILTER_POLICY[subscription_arn] = json.loads(filter_policy) subscription = { # http://docs.aws.amazon.com/cli/latest/reference/sns/get-subscription-attributes.html @@ -700,11 +727,11 @@ def subscribe( subscription.update(attributes) topic_subs.append(subscription) - if subscription_arn not in store.subscription_status: - store.subscription_status[subscription_arn] = {} + if subscription_arn not in store.SUBSCRIPTION_STATUS: + store.SUBSCRIPTION_STATUS[subscription_arn] = {} subscription_token = short_uid() - store.subscription_status[subscription_arn].update( + store.SUBSCRIPTION_STATUS[subscription_arn].update( {"TopicArn": topic_arn, "Token": subscription_token, "Status": "Not Subscribed"} ) # Send out confirmation message for HTTP(S), fix for https://github.com/localstack/localstack/issues/881 @@ -739,7 +766,7 @@ def tag_resource( call_moto(context) store = self.get_store() - existing_tags = store.sns_tags.get(resource_arn, []) + existing_tags = store.SNS_TAGS.get(resource_arn, []) def existing_tag_index(item): for idx, tag in enumerate(existing_tags): @@ -754,14 +781,15 @@ def existing_tag_index(item): else: existing_tags[existing_index] = item - store.sns_tags[resource_arn] = existing_tags + store.SNS_TAGS[resource_arn] = existing_tags return TagResourceResponse() def delete_topic(self, context: RequestContext, topic_arn: topicARN) -> None: - call_moto(context) + with self.modify_context_region_with_arn_region(context, topic_arn): + call_moto(context) store = self.get_store() - store.sns_subscriptions.pop(topic_arn, None) - store.sns_tags.pop(topic_arn, None) + store.SNS_SUBSCRIPTIONS.pop(topic_arn, None) + store.SNS_TAGS.pop(topic_arn, None) def create_topic( self, @@ -781,7 +809,7 @@ def create_topic( ) if tags: self.tag_resource(context=context, resource_arn=topic_arn, tags=tags) - store.sns_subscriptions[topic_arn] = store.sns_subscriptions.get(topic_arn) or [] + store.SNS_SUBSCRIPTIONS[topic_arn] = store.SNS_SUBSCRIPTIONS.get(topic_arn) or [] return CreateTopicResponse(TopicArn=topic_arn) @@ -789,7 +817,7 @@ def get_subscription_by_arn(sub_arn): store = SnsProvider.get_store() # TODO maintain separate map instead of traversing all items # how to deprecate the store without breaking pods/persistence - for key, subscriptions in store.sns_subscriptions.items(): + for key, subscriptions in store.SNS_SUBSCRIPTIONS.items(): for sub in subscriptions: if sub["SubscriptionArn"] == sub_arn: return sub @@ -797,10 +825,10 @@ def get_subscription_by_arn(sub_arn): def _get_tags(topic_arn): store = SnsProvider.get_store() - if topic_arn not in store.sns_tags: - store.sns_tags[topic_arn] = [] + if topic_arn not in store.SNS_TAGS: + store.SNS_TAGS[topic_arn] = [] - return store.sns_tags[topic_arn] + return store.SNS_TAGS[topic_arn] def is_raw_message_delivery(susbcriber): @@ -877,8 +905,8 @@ def validate_message_attribute_name(name: str) -> None: def extract_tags(topic_arn, tags, is_create_topic_request, store): - existing_tags = list(store.sns_tags.get(topic_arn, [])) - existing_sub = store.sns_subscriptions.get(topic_arn, None) + existing_tags = list(store.SNS_TAGS.get(topic_arn, [])) + existing_sub = store.SNS_SUBSCRIPTIONS.get(topic_arn, None) # if this is none there is nothing to check if existing_sub is not None: if tags is None: diff --git a/localstack/services/sns/publisher.py b/localstack/services/sns/publisher.py index 921311d155437..2895bfef43583 100644 --- a/localstack/services/sns/publisher.py +++ b/localstack/services/sns/publisher.py @@ -997,7 +997,7 @@ def _should_publish( Validate that the message should be relayed to the subscriber, depending on the filter policy """ subscriber_arn = subscriber["SubscriptionArn"] - filter_policy = store.subscription_filter_policy.get(subscriber_arn) + filter_policy = store.SUBSCRIPTION_FILTER_POLICY.get(subscriber_arn) if not filter_policy: return True # default value is `MessageAttributes` @@ -1011,7 +1011,7 @@ def _should_publish( return True def publish_to_topic(self, ctx: SnsPublishContext, topic_arn: str) -> None: - subscriptions = ctx.store.sns_subscriptions.get(topic_arn, []) + subscriptions = ctx.store.SNS_SUBSCRIPTIONS.get(topic_arn, []) for subscriber in subscriptions: if self._should_publish(ctx.store, ctx.message, subscriber): notifier = self.topic_notifiers[subscriber["Protocol"]] @@ -1026,7 +1026,7 @@ def publish_to_topic(self, ctx: SnsPublishContext, topic_arn: str) -> None: self.executor.submit(notifier.publish, context=ctx, subscriber=subscriber) def publish_batch_to_topic(self, ctx: SnsBatchPublishContext, topic_arn: str) -> None: - subscriptions = ctx.store.sns_subscriptions.get(topic_arn, []) + subscriptions = ctx.store.SNS_SUBSCRIPTIONS.get(topic_arn, []) for subscriber in subscriptions: protocol = subscriber["Protocol"] notifier = self.batch_topic_notifiers.get(protocol) @@ -1114,7 +1114,7 @@ def publish_to_topic_subscriber( :param subscription_arn: the ARN of the subscriber :return: None """ - subscriptions: List[SnsSubscription] = ctx.store.sns_subscriptions.get(topic_arn, []) + subscriptions: List[SnsSubscription] = ctx.store.SNS_SUBSCRIPTIONS.get(topic_arn, []) for subscriber in subscriptions: if subscriber["SubscriptionArn"] == subscription_arn: notifier = self.topic_notifiers[subscriber["Protocol"]] diff --git a/tests/integration/test_sns.py b/tests/integration/test_sns.py index 781e2737fe619..a3cbed2229267 100644 --- a/tests/integration/test_sns.py +++ b/tests/integration/test_sns.py @@ -622,7 +622,7 @@ def test_topic_subscription(self, sns_client, sns_create_topic, sns_subscription def check_subscription(): subscription_arn = subscription["SubscriptionArn"] - subscription_obj = sns_backend.subscription_status[subscription_arn] + subscription_obj = sns_backend.SUBSCRIPTION_STATUS[subscription_arn] assert subscription_obj["Status"] == "Not Subscribed" _token = subscription_obj["Token"] From 02477398b388f3fd19165ff52b8c9f532d444721 Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Tue, 10 Jan 2023 14:51:29 +0530 Subject: [PATCH 09/18] Fixes --- localstack/services/sns/provider.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/localstack/services/sns/provider.py b/localstack/services/sns/provider.py index 1804dbe85287d..af79e67b261a5 100644 --- a/localstack/services/sns/provider.py +++ b/localstack/services/sns/provider.py @@ -90,7 +90,11 @@ SnsPublishContext, ) from localstack.utils.aws import aws_stack -from localstack.utils.aws.arns import extract_region_from_arn, parse_arn +from localstack.utils.aws.arns import ( + extract_account_id_from_arn, + extract_region_from_arn, + parse_arn, +) from localstack.utils.strings import short_uid # set up logger @@ -115,23 +119,32 @@ def get_store() -> SnsStore: @staticmethod @contextmanager - def modify_context_region_with_arn_region(context: RequestContext, arn: str): + def modify_context_with_arn(context: RequestContext, arn: str): # TODO@viren this is pretty similar to DynamoDB's `modify_context_request()` + original_account_id = context.account_id original_region = context.region original_auth_header = context.request.headers.get("Authorization") arn_region = extract_region_from_arn(arn) context.region = arn_region + arn_account_id = extract_account_id_from_arn(arn) + context.account_id = arn_account_id + context.request.headers["Authorization"] = re.sub( AUTH_CREDENTIAL_REGEX, - rf"Credential=\1/\2/{arn_region}/\4/", + rf"Credential={arn_account_id}/\2/{arn_region}/\4/", original_auth_header or "", flags=re.IGNORECASE, ) + # TODO@viren:THis header is currently added in handler chain + # consider moving it to `call_moto()` + context.request.headers["x-moto-account-id"] = arn_account_id + yield context + context.account_id = original_account_id context.region = original_region context.request.headers["Authorization"] = original_auth_header @@ -273,7 +286,7 @@ def set_topic_attributes( attribute_name: attributeName, attribute_value: attributeValue = None, ) -> None: - with self.modify_context_region_with_arn_region(context, topic_arn): + with self.modify_context_with_arn(context, topic_arn): call_moto(context) def verify_sms_sandbox_phone_number( @@ -285,7 +298,7 @@ def verify_sms_sandbox_phone_number( def get_topic_attributes( self, context: RequestContext, topic_arn: topicARN ) -> GetTopicAttributesResponse: - with self.modify_context_region_with_arn_region(context, topic_arn): + with self.modify_context_with_arn(context, topic_arn): moto_response = call_moto(context) # todo fix some attributes by moto, see snapshot return GetTopicAttributesResponse(**moto_response) @@ -785,7 +798,7 @@ def existing_tag_index(item): return TagResourceResponse() def delete_topic(self, context: RequestContext, topic_arn: topicARN) -> None: - with self.modify_context_region_with_arn_region(context, topic_arn): + with self.modify_context_with_arn(context, topic_arn): call_moto(context) store = self.get_store() store.SNS_SUBSCRIPTIONS.pop(topic_arn, None) From 0b34651f0b7054ac29f39898952c6a095f2536e7 Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Tue, 10 Jan 2023 19:20:36 +0530 Subject: [PATCH 10/18] Add tests --- localstack/services/sns/provider.py | 1 + localstack/testing/pytest/fixtures.py | 8 ++++++++ tests/integration/test_multi_accounts.py | 11 ----------- tests/integration/test_sns.py | 15 +++++++++++++++ 4 files changed, 24 insertions(+), 11 deletions(-) diff --git a/localstack/services/sns/provider.py b/localstack/services/sns/provider.py index af79e67b261a5..0a056706b434d 100644 --- a/localstack/services/sns/provider.py +++ b/localstack/services/sns/provider.py @@ -121,6 +121,7 @@ def get_store() -> SnsStore: @contextmanager def modify_context_with_arn(context: RequestContext, arn: str): # TODO@viren this is pretty similar to DynamoDB's `modify_context_request()` + # TODO@viren maybe we can make this `call_moto_with_arn_context` original_account_id = context.account_id original_region = context.region original_auth_header = context.request.headers.get("Authorization") diff --git a/localstack/testing/pytest/fixtures.py b/localstack/testing/pytest/fixtures.py index 0f588631b27b3..ef11c64302705 100644 --- a/localstack/testing/pytest/fixtures.py +++ b/localstack/testing/pytest/fixtures.py @@ -135,6 +135,14 @@ def _resource(service): return aws_stack.connect_to_resource_external(service, config=config) +@pytest.fixture +def client_factory(): + def _client_factory(service: str, aws_access_key_id: str, region_name: str = "eu-central-1"): + return _client(service, region_name=region_name, aws_access_key_id=aws_access_key_id) + + yield _client_factory + + @pytest.fixture(scope="class") def create_boto_client(): return _client diff --git a/tests/integration/test_multi_accounts.py b/tests/integration/test_multi_accounts.py index cf1634d50cb07..0437e39fbcf38 100644 --- a/tests/integration/test_multi_accounts.py +++ b/tests/integration/test_multi_accounts.py @@ -1,17 +1,6 @@ -import pytest - -from localstack.testing.pytest.fixtures import _client from localstack.utils.strings import short_uid -@pytest.fixture -def client_factory(): - def _client_factory(service: str, aws_access_key_id: str, region_name: str = "eu-central-1"): - return _client(service, region_name=region_name, aws_access_key_id=aws_access_key_id) - - yield _client_factory - - class TestMultiAccounts: def test_account_id_namespacing_for_moto_backends(self, client_factory): # diff --git a/tests/integration/test_sns.py b/tests/integration/test_sns.py index a3cbed2229267..ae0ae604cc56d 100644 --- a/tests/integration/test_sns.py +++ b/tests/integration/test_sns.py @@ -2985,3 +2985,18 @@ def test_message_structure_json_exc(self, sns_client, sns_create_topic, snapshot MessageStructure="json", ) snapshot.match("duplicate-json-keys", resp) + + def test_sns_cross_account_access(self, client_factory): + # Ensure a topic can be retrieved from any region and account ID + sns_client1 = client_factory( + "sns", aws_access_key_id="424242424242", region_name="eu-central-1" + ) + sns_client2 = client_factory( + "sns", aws_access_key_id="100010001000", region_name="ap-south-1" + ) + + topic_name = f"topic-{short_uid()}" + + topic_arn = sns_client1.create_topic(Name=topic_name)["TopicArn"] + + assert sns_client2.get_topic_attributes(TopicArn=topic_arn) From 8ec3aa98c9a47b87f02cf8224d11886ab1c12807 Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Thu, 12 Jan 2023 14:53:53 +0530 Subject: [PATCH 11/18] Remove duplicate assignment --- localstack/services/stores.py | 1 - 1 file changed, 1 deletion(-) diff --git a/localstack/services/stores.py b/localstack/services/stores.py index f7f59578be275..3aa045cfebb89 100644 --- a/localstack/services/stores.py +++ b/localstack/services/stores.py @@ -208,7 +208,6 @@ def __init__( self.service_name = service_name self.validate = validate self.lock = lock or RLock() - self._universal = universal self.valid_regions = get_valid_regions_for_service(service_name) From 548ced386776433a570bfdd578d921ee224c71db Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Thu, 12 Jan 2023 15:09:00 +0530 Subject: [PATCH 12/18] Fallback to default internal credentials --- localstack/aws/client.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/localstack/aws/client.py b/localstack/aws/client.py index c8c0a3fb866a4..f68e3c0c6e78e 100644 --- a/localstack/aws/client.py +++ b/localstack/aws/client.py @@ -376,15 +376,6 @@ def with_credentials( ) ) - def with_default_credentials(self) -> "ClientFactory": - """ - Use LocalStack default AWS credentials. - """ - return self.with_credentials( - aws_access_key_id=INTERNAL_AWS_ACCESS_KEY_ID, - aws_secret_access_key=INTERNAL_AWS_SECRET_ACCESS_KEY, - ) - def with_env_credentials(self) -> "ClientFactory": """ Use AWS credentials from the environment. @@ -406,9 +397,10 @@ def build(self, service: str) -> BaseClient: """ Finalise the client. """ - assert self.client_options.aws_access_key_id, "Access key ID is not set" - assert self.client_options.aws_secret_access_key, "Secret access key is not set" - + aws_access_key_id = self.client_options.aws_access_key_id or INTERNAL_AWS_ACCESS_KEY_ID + aws_secret_access_key = ( + self.client_options.aws_secret_access_key or INTERNAL_AWS_SECRET_ACCESS_KEY + ) endpoint_url = self.client_options.endpoint_url or get_local_service_url(service) # TODO@viren: creating a boto client is very intensive. In old aws_stack, we cache clients based on @@ -417,8 +409,8 @@ def build(self, service: str) -> BaseClient: client = self.session.client( service_name=service, config=self.client_options.boto_config, - aws_access_key_id=self.client_options.aws_access_key_id, - aws_secret_access_key=self.client_options.aws_secret_access_key, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, endpoint_url=endpoint_url, ) From f6f37fa09aaea72e59ddffcbb5b5228921235fbb Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Thu, 12 Jan 2023 16:01:27 +0530 Subject: [PATCH 13/18] Proper loading of default credentials --- localstack/aws/client.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/localstack/aws/client.py b/localstack/aws/client.py index f68e3c0c6e78e..bfb553e51327b 100644 --- a/localstack/aws/client.py +++ b/localstack/aws/client.py @@ -378,10 +378,13 @@ def with_credentials( def with_env_credentials(self) -> "ClientFactory": """ - Use AWS credentials from the environment. + Use AWS credentials from the following locations: + - Environment variables + - Credentials file `~/.aws/credentials` + - Config file `~/.aws/config` """ - # TODO wrong output format of session.get_credentials() - return self.credentials(self.session.get_credentials()) + credentials = self.session.get_credentials() + return self.with_credentials(credentials.access_key, credentials.secret_key) def with_boto_config(self, config: BotoConfig) -> "ClientFactory": """ From f3922420fb13682dc4db91977c57a637ec788067 Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Thu, 12 Jan 2023 21:10:14 +0530 Subject: [PATCH 14/18] Move to its own module --- localstack/aws/client.py | 219 +------------------------------------- localstack/aws/connect.py | 207 +++++++++++++++++++++++++++++++++++ 2 files changed, 208 insertions(+), 218 deletions(-) create mode 100644 localstack/aws/connect.py diff --git a/localstack/aws/client.py b/localstack/aws/client.py index bfb553e51327b..726473c63b70c 100644 --- a/localstack/aws/client.py +++ b/localstack/aws/client.py @@ -1,29 +1,17 @@ """Utils to process AWS requests as a client.""" -import dataclasses import io -import json import logging from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, TypedDict +from typing import Dict, Iterable, Optional -from boto3 import Session -from botocore.awsrequest import AWSPreparedRequest -from botocore.client import BaseClient -from botocore.config import Config as BotoConfig from botocore.model import OperationModel from botocore.parsers import ResponseParser, ResponseParserFactory from werkzeug import Response from localstack.aws.api import CommonServiceException, ServiceException, ServiceResponse -from localstack.constants import INTERNAL_AWS_ACCESS_KEY_ID, INTERNAL_AWS_SECRET_ACCESS_KEY from localstack.runtime import hooks -from localstack.utils.aws.arns import extract_region_from_arn -from localstack.utils.aws.aws_stack import get_local_service_url from localstack.utils.patch import patch -if TYPE_CHECKING: - from mypy_boto3_sqs import SQSClient - LOG = logging.getLogger(__name__) @@ -232,208 +220,3 @@ def raise_service_exception(response: Response, parsed_response: Dict) -> None: """ if service_exception := parse_service_exception(response, parsed_response): raise service_exception - - -# -# Internal AWS client -# - -""" -The internal AWS client API provides the means to perform cross-service -communication within LocalStack. - -Any additional information LocalStack might need for the purpose of policy -enforcement is sent as a data transfer object. This is a serialised dict object -sent in the request header. -""" - -LOCALSTACK_DATA_HEADER = "x-localstack-data" -"""Request header which contains the data transfer object.""" - - -class LocalStackData(TypedDict): - """ - LocalStack Data Transfer Object. - """ - - source_arn: str - source_service: str # eg. 'ec2.amazonaws.com' - - -class Credentials(TypedDict): - """ - AWS credentials. - """ - - aws_access_key_id: str - aws_secret_access_key: str - aws_session_token: str - - -@dataclasses.dataclass(frozen=True) -class ClientOptions: - """This object holds configuration options for the internal AWS client.""" - - region_name: Optional[str] = None - """Name of the AWS region to be associated with the client.""" - - endpoint_url: Optional[str] = None - """Full endpoint URL to be used by the client.""" - - use_ssl: bool = True - """Whether or not to use SSL.""" - - verify: bool = True - """Whether or not to verify SSL certificates.""" - - aws_access_key_id: Optional[str] = None - """Access key to use for the client.""" - - aws_secret_access_key: Optional[str] = None - """Secret key to use for the client.""" - - aws_session_token: Optional[str] = None - """Session token to use for the client""" - - boto_config: Optional[BotoConfig] = dataclasses.field(default_factory=BotoConfig) - """Boto client configuration for advanced use.""" - - localstack_data: dict[str, Any] = dataclasses.field(default_factory=LocalStackData) - """LocalStack data transfer object.""" - - -class ClientFactory: - """ - Factory to build the internal AWS client. - """ - - # TODO migrate to immutable clientfactory instances - client_options: ClientOptions - session: Session - - def __init__(self, client_options: ClientOptions = None): - self.client_options = client_options or ClientOptions() - self.session = Session() - - def with_endpoint(self, endpoint: str) -> "ClientFactory": - """ - Set a custom endpoint. - """ - return ClientFactory( - client_options=dataclasses.replace(self.client_options, endpoint_url=endpoint) - ) - - def with_source_arn(self, arn: str) -> "ClientFactory": - """ - Indicate that the client is operating from a given resource. - - This must be used in cross-service requests. - """ - return ClientFactory( - client_options=dataclasses.replace( - self.client_options, - localstack_data=self.client_options.localstack_data - | LocalStackData(source_arn=arn), - ) - ) - - def with_target_arn(self, arn: str) -> "ClientFactory": - """ - Create the client to operate on a target resource. - - This must be used in cross-service requests. - """ - region_name = extract_region_from_arn(arn) - return ClientFactory( - client_options=dataclasses.replace(self.client_options, region_name=region_name) - ) - - def with_source_service_principal(self, source_service: str) -> "ClientFactory": - """ - Set the source service principal. - - This must be used in cross-service requests. - """ - return ClientFactory( - client_options=dataclasses.replace( - self.client_options, - localstack_data=self.client_options.localstack_data - | LocalStackData(source_service=f"{source_service}.amazonaws.com"), - ) - ) - - def with_credentials( - self, aws_access_key_id: str, aws_secret_access_key: str - ) -> "ClientFactory": - """ - Use custom AWS credentials. - """ - return ClientFactory( - client_options=dataclasses.replace( - self.client_options, - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - ) - ) - - def with_env_credentials(self) -> "ClientFactory": - """ - Use AWS credentials from the following locations: - - Environment variables - - Credentials file `~/.aws/credentials` - - Config file `~/.aws/config` - """ - credentials = self.session.get_credentials() - return self.with_credentials(credentials.access_key, credentials.secret_key) - - def with_boto_config(self, config: BotoConfig) -> "ClientFactory": - """ - Use a custom BotoConfig. - """ - return ClientFactory( - client_options=dataclasses.replace( - self.client_options, boto_config=self.client_options.boto_config.merge(config) - ) - ) - - def build(self, service: str) -> BaseClient: - """ - Finalise the client. - """ - aws_access_key_id = self.client_options.aws_access_key_id or INTERNAL_AWS_ACCESS_KEY_ID - aws_secret_access_key = ( - self.client_options.aws_secret_access_key or INTERNAL_AWS_SECRET_ACCESS_KEY - ) - endpoint_url = self.client_options.endpoint_url or get_local_service_url(service) - - # TODO@viren: creating a boto client is very intensive. In old aws_stack, we cache clients based on - # [service_name, client, env, region, endpoint_url, config, internal, kwargs] - # Come up with an appropriate solution here - client = self.session.client( - service_name=service, - config=self.client_options.boto_config, - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - endpoint_url=endpoint_url, - ) - - def event_handler(request: AWSPreparedRequest, **_): - # Send a compact JSON representation as DTO - request.headers[LOCALSTACK_DATA_HEADER] = json.dumps( - self.client_options.localstack_data, separators=(",", ":") - ) - - client.meta.events.register("before-send.*.*", handler=event_handler) - - return client - - # - # Convenience helpers - # - - def sqs(self) -> "SQSClient": - return self.build("sqs") - - -def aws_client(): - return ClientFactory() diff --git a/localstack/aws/connect.py b/localstack/aws/connect.py new file mode 100644 index 0000000000000..864f714dd1201 --- /dev/null +++ b/localstack/aws/connect.py @@ -0,0 +1,207 @@ +""" +Internal AWS client. + +This module provides the interface to perform cross-service communication between +LocalStack providers. + + from localstack.aws.connect import connect_to + + key_pairs = connect_to('ec2').describe_key_pairs() + buckets = connect_to('s3', region='ap-south-1').list_buckets() +""" +import json +from datetime import datetime, timezone +from functools import cache +from typing import Optional, TypedDict + +from boto3.session import Session +from botocore.awsrequest import AWSPreparedRequest +from botocore.client import BaseClient +from botocore.config import Config + +from localstack import config +from localstack.constants import ( + INTERNAL_AWS_ACCESS_KEY_ID, + INTERNAL_AWS_SECRET_ACCESS_KEY, + MAX_POOL_CONNECTIONS, +) +from localstack.utils.aws.arns import extract_region_from_arn +from localstack.utils.aws.aws_stack import get_local_service_url +from localstack.utils.aws.request_context import get_region_from_request_context + +# +# Data transfer object +# + +LOCALSTACK_DATA_HEADER = "x-localstack-data" +"""Request header which contains the data transfer object.""" + + +class LocalStackData(TypedDict): + """ + LocalStack Data Transfer Object. + + This is sent with every internal request and contains any additional information + LocalStack might need for the purpose of policy enforcement. It is serialised + into text and sent in the request header. + + The keys approximately correspond to: + https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_condition-keys.html + """ + + current_time: str + """Request datetime in ISO8601 format""" + + source_arn: str + """ARN of resource which is triggering the call""" + + source_service: str + """Service principal where the call originates, eg. `ec2`""" + + target_arn: str + """ARN of the resource being targetted.""" + + +def dump_dto(data: LocalStackData) -> str: + # TODO@viren: Improve minification using custom JSONEncoder that use shortened keys + return json.dumps(data, separators=(",", ":")) + + +def load_dto(data: str) -> LocalStackData: + return json.loads(data) + + +# +# Client +# + + +class ConnectFactory: + """ + Factory to build the internal AWS client. + """ + + def __init__( + self, + use_ssl: bool = False, + verify: bool = False, + aws_access_key_id: Optional[str] = INTERNAL_AWS_ACCESS_KEY_ID, + aws_secret_access_key: Optional[str] = INTERNAL_AWS_SECRET_ACCESS_KEY, + ): + """ + If either of the access keys are set to None, they are loaded from following + locations: + - AWS environment variables + - Credentials file `~/.aws/credentials` + - Config file `~/.aws/config` + + :param use_ssl: Whether to use SSL + :param verify: Whether to verify SSL certificates + :param aws_access_key_id: Access key to use for the client. + If set to None, loads them from botocore session. See above. + :param aws_secret_access_key: Secret key to use for the client. + If set to None, loads them from botocore session. See above. + :param localstack_data: LocalStack data transfer object + """ + self._use_ssl = use_ssl + self._verify = verify + self._aws_access_key_id = aws_access_key_id + self._aws_secret_access_key = aws_secret_access_key + self._aws_session_token = None + self._session = Session() + self._config = Config(max_pool_connections=MAX_POOL_CONNECTIONS) + + def get_region(self) -> str: + return ( + get_region_from_request_context() or config.DEFAULT_REGION or self._session.region_name + ) + + # TODO@viren is this thread safe? + @cache + def get_client( + self, + service_name: str, + region_name: str, + use_ssl: bool, + verify: bool, + endpoint_url: str, + aws_access_key_id: str, + aws_secret_access_key: str, + aws_session_token: str, + config: Config, + ) -> BaseClient: + return self._session.client( + service_name=service_name, + region_name=region_name, + use_ssl=use_ssl, + verify=verify, + endpoint_url=endpoint_url, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + config=config, + ) + + def __call__( + self, + target_service: str, + region_name: str = None, + endpoint_url: str = None, + config: Config = None, + source_arn: str = None, + source_service: str = None, + target_arn: str = None, + ) -> BaseClient: + """ + Build and return the client. + + :param target_service: Service to build the client for, eg. `s3` + :param region_name: Name of the AWS region to be associated with the client + :param endpoint_url: Full endpoint URL to be used by the client. + Defaults to appropriate LocalStack endpoint. + :param config: Boto config for advanced use. + :param source_arn: ARN of resource which triggers the call. Required for + internal calls. + :param source_service: Service name where call originates. Required for + internal calls. + :param target_arn: ARN of targeted resource. Overrides `region_name`. + Required for internal calls. + """ + localstack_data = LocalStackData() + + if source_arn: + localstack_data["source_arn"] = source_arn + + if source_service: + localstack_data["source_service"] = source_service + + if target_arn: + region_name = extract_region_from_arn(target_arn) + localstack_data["target_arn"] = target_arn + + client = self.get_client( + service_name=target_service, + region_name=region_name or self.get_region(), + use_ssl=self._use_ssl, + verify=self._verify, + endpoint_url=endpoint_url or get_local_service_url(target_service), + aws_access_key_id=self._aws_access_key_id, + aws_secret_access_key=self._aws_secret_access_key, + aws_session_token=self._aws_session_token, + config=config or self._config, + ) + + def _handler(request: AWSPreparedRequest, **_): + data = localstack_data | LocalStackData( + current_time=datetime.utcnow(timezone.utc).isoformat() + ) + + # Use a compact JSON representation of DTO + request.headers[LOCALSTACK_DATA_HEADER] = dump_dto(data) + + client.meta.events.register("before-send.*.*", handler=_handler) + + return client + + +connect_to = ConnectFactory() From f8edc9c7691fea055e64247805e89c9ddbf4e382 Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Fri, 13 Jan 2023 17:51:53 +0530 Subject: [PATCH 15/18] Fix datetime --- localstack/aws/connect.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/localstack/aws/connect.py b/localstack/aws/connect.py index 864f714dd1201..be410f73513c2 100644 --- a/localstack/aws/connect.py +++ b/localstack/aws/connect.py @@ -193,7 +193,7 @@ def __call__( def _handler(request: AWSPreparedRequest, **_): data = localstack_data | LocalStackData( - current_time=datetime.utcnow(timezone.utc).isoformat() + current_time=datetime.now(timezone.utc).isoformat() ) # Use a compact JSON representation of DTO From 3a2a669c5c90a5909bb3c3a62bc287e3e3a5391e Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Fri, 13 Jan 2023 17:52:18 +0530 Subject: [PATCH 16/18] WIP --- localstack/aws/handlers/auth.py | 18 ++++++++++++++++-- localstack/constants.py | 3 +++ localstack/services/ses/provider.py | 14 ++------------ 3 files changed, 21 insertions(+), 14 deletions(-) diff --git a/localstack/aws/handlers/auth.py b/localstack/aws/handlers/auth.py index 47e36bef89198..7fe099d1ebe25 100644 --- a/localstack/aws/handlers/auth.py +++ b/localstack/aws/handlers/auth.py @@ -5,7 +5,11 @@ set_aws_access_key_id, set_aws_account_id, ) -from localstack.constants import TEST_AWS_ACCESS_KEY_ID +from localstack.constants import ( + INTERNAL_AWS_ACCESS_KEY_ID, + INTERNAL_AWS_ACCOUNT_ID, + TEST_AWS_ACCESS_KEY_ID, +) from localstack.http import Response from localstack.utils.aws.aws_stack import extract_access_key_id_from_auth_header @@ -45,7 +49,17 @@ def __call__(self, chain: HandlerChain, context: RequestContext, response: Respo set_aws_access_key_id(access_key_id) # Obtain the account ID and save it in the request context - context.account_id = get_account_id_from_access_key_id(access_key_id) + if access_key_id == INTERNAL_AWS_ACCESS_KEY_ID: + # For internal calls, a special account ID is used for request context + # Cross account calls don't have the same auth flows as user-originating calls + # which means there is no true Account ID. + # The invocations of `get_aws_account_id()` used to resolve the stores must not break. + # We don't use the DEFAULT_AWS_ACCOUNT_ID either to help identify bugs. + # If correctly implemented with CrossAccountAttribute and ARNs, the provider + # will work with this internal AWS account ID. + context.account_id = INTERNAL_AWS_ACCOUNT_ID + else: + context.account_id = get_account_id_from_access_key_id(access_key_id) # Save the same account ID in the thread context set_aws_account_id(context.account_id) diff --git a/localstack/constants.py b/localstack/constants.py index 31b7288fd0350..9d20d748ecc5b 100644 --- a/localstack/constants.py +++ b/localstack/constants.py @@ -45,6 +45,9 @@ # Fallback Account ID if not available in the client request DEFAULT_AWS_ACCOUNT_ID = "000000000000" +# Fallback Account ID for internal calls +INTERNAL_AWS_ACCOUNT_ID = "000012210000" + # AWS user account ID used for tests - TODO move to config.py if "TEST_AWS_ACCOUNT_ID" not in os.environ: os.environ["TEST_AWS_ACCOUNT_ID"] = DEFAULT_AWS_ACCOUNT_ID diff --git a/localstack/services/ses/provider.py b/localstack/services/ses/provider.py index a0af7f6ba805d..e8f94933b45b4 100644 --- a/localstack/services/ses/provider.py +++ b/localstack/services/ses/provider.py @@ -50,12 +50,11 @@ VerificationAttributes, VerificationStatus, ) -from localstack.constants import TEST_AWS_SECRET_ACCESS_KEY +from localstack.aws.connect import connect_to from localstack.services.internal import get_internal_apis from localstack.services.moto import call_moto from localstack.services.plugins import ServiceLifecycleHook from localstack.services.ses.models import SentEmail, SentEmailBody -from localstack.utils.aws import arns, aws_stack from localstack.utils.files import mkdir from localstack.utils.strings import long_uid, to_str from localstack.utils.time import timestamp, timestamp_millis @@ -577,13 +576,4 @@ def emit_delivery_event(self, payload: SNSPayload, sns_topic_arn: str): @staticmethod def _client_for_topic(topic_arn: str) -> "SNSClient": - arn_parameters = arns.parse_arn(topic_arn) - region = arn_parameters["region"] - access_key_id = arn_parameters["account"] - - return aws_stack.connect_to_service( - "sns", - region_name=region, - aws_access_key_id=access_key_id, - aws_secret_access_key=TEST_AWS_SECRET_ACCESS_KEY, - ) + return connect_to("sns", target_arn=topic_arn, source_service="ses") From b53e0684d6610c43fc046c349602ecb510f67648 Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Mon, 16 Jan 2023 15:45:01 +0530 Subject: [PATCH 17/18] Allow module to be used for external clients also --- localstack/aws/connect.py | 56 ++++++++++++++++++++++++++++++++------- 1 file changed, 46 insertions(+), 10 deletions(-) diff --git a/localstack/aws/connect.py b/localstack/aws/connect.py index be410f73513c2..a17a2e9dda4f0 100644 --- a/localstack/aws/connect.py +++ b/localstack/aws/connect.py @@ -1,5 +1,5 @@ """ -Internal AWS client. +LocalStack client stack. This module provides the interface to perform cross-service communication between LocalStack providers. @@ -20,6 +20,7 @@ from botocore.config import Config from localstack import config +from localstack.aws.api import RequestContext from localstack.constants import ( INTERNAL_AWS_ACCESS_KEY_ID, INTERNAL_AWS_SECRET_ACCESS_KEY, @@ -59,11 +60,13 @@ class LocalStackData(TypedDict): """Service principal where the call originates, eg. `ec2`""" target_arn: str - """ARN of the resource being targetted.""" + """ARN of the resource being targeted.""" def dump_dto(data: LocalStackData) -> str: # TODO@viren: Improve minification using custom JSONEncoder that use shortened keys + + # To produce a compact JSON representation of DTO, remove spaces from separators return json.dumps(data, separators=(",", ":")) @@ -78,7 +81,7 @@ def load_dto(data: str) -> LocalStackData: class ConnectFactory: """ - Factory to build the internal AWS client. + Factory to build the AWS client. """ def __init__( @@ -111,12 +114,31 @@ def __init__( self._session = Session() self._config = Config(max_pool_connections=MAX_POOL_CONNECTIONS) - def get_region(self) -> str: + def get_partition_for_region(self, region_name: str) -> str: + """ + Return the AWS partition name for a given region, eg. `aws`, `aws-cn`, etc. + """ + return self._session.get_partition_for_region(region_name) + + def get_session_region_name(self) -> str: + """ + Return AWS region as set in the Boto session. + """ + return self._session.region_name + + def get_region_name(self) -> str: + """ + Return the AWS region name from following sources, in order of availability. + - LocalStack request context + - LocalStack default region + - Boto session + """ return ( - get_region_from_request_context() or config.DEFAULT_REGION or self._session.region_name + get_region_from_request_context() + or config.DEFAULT_REGION + or self.get_session_region_name() ) - # TODO@viren is this thread safe? @cache def get_client( self, @@ -155,6 +177,9 @@ def __call__( """ Build and return the client. + Presence of any attribute apart from `source_*` or `target_*` argument + indicates that this is a client meant for internal calls. + :param target_service: Service to build the client for, eg. `s3` :param region_name: Name of the AWS region to be associated with the client :param endpoint_url: Full endpoint URL to be used by the client. @@ -176,12 +201,13 @@ def __call__( localstack_data["source_service"] = source_service if target_arn: + # Attention: region is overriden here region_name = extract_region_from_arn(target_arn) localstack_data["target_arn"] = target_arn client = self.get_client( service_name=target_service, - region_name=region_name or self.get_region(), + region_name=region_name or self.get_region_name(), use_ssl=self._use_ssl, verify=self._verify, endpoint_url=endpoint_url or get_local_service_url(target_service), @@ -195,13 +221,23 @@ def _handler(request: AWSPreparedRequest, **_): data = localstack_data | LocalStackData( current_time=datetime.now(timezone.utc).isoformat() ) - - # Use a compact JSON representation of DTO request.headers[LOCALSTACK_DATA_HEADER] = dump_dto(data) - client.meta.events.register("before-send.*.*", handler=_handler) + if len(localstack_data): + client.meta.events.register("before-send.*.*", handler=_handler) return client connect_to = ConnectFactory() + +# +# Utilities +# + + +def is_internal_call(context: RequestContext) -> bool: + """ + Whether a given request is an internal LocalStack cross-service call. + """ + return LOCALSTACK_DATA_HEADER in context.request.headers From 169f3feed7d57461e81571cf86ba530ff061d24f Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Mon, 16 Jan 2023 15:51:11 +0530 Subject: [PATCH 18/18] Use new client at all places SNS was previously being used --- localstack/services/s3/notifications.py | 3 ++- localstack/services/s3/s3_listener.py | 6 ++++-- localstack/utils/aws/dead_letter_queue.py | 5 ++++- localstack/utils/aws/message_forwarding.py | 4 +++- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/localstack/services/s3/notifications.py b/localstack/services/s3/notifications.py index 56ce022a6234c..b76ca9699e6db 100644 --- a/localstack/services/s3/notifications.py +++ b/localstack/services/s3/notifications.py @@ -28,6 +28,7 @@ TopicArn, TopicConfiguration, ) +from localstack.aws.connect import connect_to from localstack.config import DEFAULT_REGION from localstack.services.s3.models import get_moto_s3_backend from localstack.services.s3.utils import ( @@ -350,7 +351,7 @@ def _get_arn_value_and_name(topic_configuration: TopicConfiguration) -> [TopicAr return topic_configuration.get("TopicArn", ""), "TopicArn" def _verify_target_exists(self, arn: str, arn_data: ArnData) -> None: - client = aws_stack.connect_to_service(self.service_name, region_name=arn_data["region"]) + client = connect_to(self.service_name, target_arn=arn, source_service="s3") try: client.get_topic_attributes(TopicArn=arn) except ClientError: diff --git a/localstack/services/s3/s3_listener.py b/localstack/services/s3/s3_listener.py index c7527f60960fa..94eb52f14a561 100644 --- a/localstack/services/s3/s3_listener.py +++ b/localstack/services/s3/s3_listener.py @@ -21,6 +21,7 @@ from localstack import config, constants from localstack.aws.api import CommonServiceException +from localstack.aws.connect import connect_to from localstack.config import get_protocol as get_service_protocol from localstack.services.generic_proxy import ProxyListener from localstack.services.generic_proxy import append_cors_headers as _append_default_cors_headers @@ -368,10 +369,11 @@ def send_notification_for_subscriber( ) if notification.get("Topic"): region = arns.extract_region_from_arn(notification["Topic"]) - sns_client = aws_stack.connect_to_service("sns", region_name=region) + topic_arn = notification["Topic"] + sns_client = connect_to("sns", target_arn=topic_arn, source_service="s3") try: sns_client.publish( - TopicArn=notification["Topic"], + TopicArn=topic_arn, Message=message, Subject="Amazon S3 Notification", ) diff --git a/localstack/utils/aws/dead_letter_queue.py b/localstack/utils/aws/dead_letter_queue.py index a0afa9162b433..a3376cfe7bafa 100644 --- a/localstack/utils/aws/dead_letter_queue.py +++ b/localstack/utils/aws/dead_letter_queue.py @@ -3,6 +3,7 @@ import uuid from typing import Dict, List +from localstack.aws.connect import connect_to from localstack.utils.aws import arns, aws_stack from localstack.utils.aws.aws_models import LambdaFunction from localstack.utils.strings import convert_to_printable_chars, first_char_to_upper @@ -52,7 +53,9 @@ def _send_to_dead_letter_queue(source_arn: str, dlq_arn: str, event: Dict, error LOG.info(msg) raise Exception(msg) elif ":sns:" in dlq_arn: - sns_client = aws_stack.connect_to_service("sns") + sns_client = connect_to( + "sns", target_arn=dlq_arn, source_service="lambda", source_arn=source_arn + ) for message in messages: sns_client.publish( TopicArn=dlq_arn, diff --git a/localstack/utils/aws/message_forwarding.py b/localstack/utils/aws/message_forwarding.py index ab9f7f809db41..0f38096e63a74 100644 --- a/localstack/utils/aws/message_forwarding.py +++ b/localstack/utils/aws/message_forwarding.py @@ -7,6 +7,7 @@ from moto.events.models import events_backends +from localstack.aws.connect import connect_to from localstack.services.apigateway.helpers import extract_query_string_params from localstack.utils import collections from localstack.utils.aws.arns import ( @@ -35,6 +36,7 @@ def send_event_to_target( asynchronous: bool = True, target: Dict = None, ): + # TODO@viren Refactor to accept source ARN and source service, and send them with all `connect_to` calls region = extract_region_from_arn(target_arn) if target is None: target = {} @@ -48,7 +50,7 @@ def send_event_to_target( ) elif ":sns:" in target_arn: - sns_client = connect_to_service("sns", region_name=region) + sns_client = connect_to("sns", target_arn=target_arn) sns_client.publish(TopicArn=target_arn, Message=json.dumps(event)) elif ":sqs:" in target_arn: