8000 EventBridge: Multi-accounts compatibility (#9023) · codeperl/localstack@0cdaf55 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0cdaf55

Browse files
EventBridge: Multi-accounts compatibility (localstack#9023)
1 parent 7c6bec1 commit 0cdaf55

File tree

34 files changed

+288
-196
lines changed
  • testing/pytest
  • utils
  • tests
  • 34 files changed

    +288
    -196
    lines changed

    localstack/aws/connect.py

    Lines changed: 1 addition & 1 deletion
    Original file line numberDiff line numberDiff line change
    @@ -132,7 +132,7 @@ def request_metadata(
    132132
    self, source_arn: str | None = None, service_principal: str | None = None
    133133
    ) -> T:
    134134
    """
    135-
    Provides request metadata to this client.
    135+
    Returns a new client instance preset with the given request metadata.
    136136
    Identical to providing _ServicePrincipal and _SourceArn directly as operation arguments but typing
    137137
    compatible.
    138138

    localstack/services/apigateway/helpers.py

    Lines changed: 5 additions & 4 deletions
    Original file line numberDiff line numberDiff line change
    @@ -746,18 +746,19 @@ def path_matches_pattern(path, api_path):
    746746
    return len(results) > 0 and all(results)
    747747

    748748

    749-
    def connect_api_gateway_to_sqs(gateway_name, stage_name, queue_arn, path, region_name=None):
    749+
    def connect_api_gateway_to_sqs(gateway_name, stage_name, queue_arn, path, account_id, region_name):
    750750
    resources = {}
    751751
    template = APIGATEWAY_SQS_DATA_INBOUND_TEMPLATE
    752752
    resource_path = path.replace("/", "")
    753-
    region_name = region_name or aws_stack.get_region()
    754753

    755754
    try:
    756755
    arn = parse_arn(queue_arn)
    757756
    queue_name = arn["resource"]
    757+
    sqs_account = arn["account"]
    758758
    sqs_region = arn["region"]
    759759
    except InvalidArnException:
    760760
    queue_name = queue_arn
    761+
    sqs_account = account_id
    761762
    sqs_region = region_name
    762763

    763764
    resources[resource_path] = [
    @@ -768,7 +769,7 @@ def connect_api_gateway_to_sqs(gateway_name, stage_name, queue_arn, path, region
    768769
    {
    769770
    "type": "AWS",
    770771
    "uri": "arn:aws:apigateway:%s:sqs:path/%s/%s"
    771-
    % (sqs_region, get_aws_account_id(), queue_name),
    772+
    % (sqs_region, sqs_account, queue_name),
    772773
    "requestTemplates": {"application/json": template},
    773774
    }
    774775
    ],
    @@ -778,7 +779,7 @@ def connect_api_gateway_to_sqs(gateway_name, stage_name, queue_arn, path, region
    778779
    name=gateway_name,
    779780
    resources=resources,
    780781
    stage_name=stage_name,
    781-
    region_name=region_name,
    782+
    client=connect_to(aws_access_key_id=sqs_account, region_name=sqs_region).apigateway,
    782783
    )
    783784

    784785

    localstack/services/cloudformation/models/es.py

    Lines changed: 1 addition & 1 deletion
    Original file line numberDiff line numberDiff line change
    @@ -16,7 +16,7 @@ def es_add_tags_params(
    1616
    resource: dict,
    1717
    stack_name: str,
    1818
    ):
    19-
    es_arn = arns.es_domain_arn(properties.get("DomainName"))
    19+
    es_arn = arns.es_domain_arn(properties.get("DomainName"), account_id, region_name)
    2020
    tags = properties.get("Tags", [])
    2121
    return {"ARN": es_arn, "TagList": tags}
    2222

    localstack/services/cloudformation/models/lambda_.py

    Lines changed: 3 additions & 1 deletion
    Original file line numberDiff line numberDiff line change
    @@ -83,7 +83,9 @@ def update_resource(self, new_resource, stack_name, resources):
    8383
    k: str(v) for k, v in environment_variables.items()
    8484
    }
    8585
    result = client.update_function_configuration(**update_config_props)
    86-
    connect_to().lambda_.get_waiter("function_updated_v2").wait(FunctionName=function_name)
    86+
    connect_to(
    87+
    aws_access_key_id=self.account_id, region_name=self.region_name
    88+
    ).lambda_.get_waiter("function_updated_v2").wait(FunctionName=function_name)
    8789
    return result
    8890

    8991
    @staticmethod

    localstack/services/cloudformation/models/opensearch.py

    Lines changed: 1 addition & 1 deletion
    Original file line numberDiff line numberDiff line change
    @@ -24,7 +24,7 @@ def opensearch_add_tags_params(
    2424
    resource: dict,
    2525
    stack_name: str,
    2626
    ):
    27-
    es_arn = arns.es_domain_arn(properties.get("DomainName"))
    27+
    es_arn = arns.es_domain_arn(properties.get("DomainName"), account_id, region_name)
    2828
    tags = properties.get("Tags", [])
    2929
    return {"ARN": es_arn, "TagList": tags}
    3030

    localstack/services/events/provider.py

    Lines changed: 1 addition & 2 deletions
    Original file line numberDiff line numberDiff line change
    @@ -11,7 +11,6 @@
    1111
    from moto.events.responses import EventsHandler as MotoEventsHandler
    1212

    1313
    from localstack import config
    14-
    from localstack.aws.accounts import get_aws_account_id
    1514
    from localstack.aws.api import RequestContext
    1615
    from localstack.aws.api.core import CommonServiceException
    1716
    from localstack.aws.api.events import (
    @@ -583,7 +582,7 @@ def events_handler_put_events(self):
    583582
    "id": event_envelope["uuid"],
    584583
    "detail-type": event.get("DetailType"),
    585584
    "source": event.get("Source"),
    586-
    "account": get_aws_account_id(),
    585+
    "account": self.current_account,
    587586
    "time": event_time.strftime("%Y-%m-%dT%H:%M:%SZ"),
    588587
    "region": self.region,
    589588
    "resources": event.get("Resources", []),

    localstack/services/kinesis/provider.py

    Lines changed: 1 addition & 1 deletion
    Original file line numberDiff line numberDiff line change
    @@ -46,7 +46,7 @@ def find_stream_for_consumer(consumer_arn):
    4646
    region_name = extract_region_from_arn(consumer_arn)
    4747
    kinesis = connect_to(aws_access_key_id=account_id, region_name=region_name).kinesis
    4848
    for stream_name in kinesis.list_streams()["StreamNames"]:
    49-
    stream_arn = arns.kinesis_stream_arn(stream_name)
    49+
    stream_arn = arns.kinesis_stream_arn(stream_name, account_id, region_name)
    5050
    for cons in kinesis.list_stream_consumers(StreamARN=stream_arn)["Consumers"]:
    5151
    if cons["ConsumerARN"] == consumer_arn:
    5252
    return stream_name

    localstack/services/s3/notifications.py

    Lines changed: 28 additions & 15 deletions
    Original file line numberDiff line numberDiff line change
    @@ -33,7 +33,6 @@
    3333
    TopicConfiguration,
    3434
    )
    3535
    from localstack.aws.connect import connect_to
    36-
    from localstack.config import DEFAULT_REGION
    3736
    from localstack.services.s3.models import get_moto_s3_backend
    3837
    from localstack.services.s3.utils import (
    3938
    _create_invalid_argument_exc,
    @@ -93,6 +92,7 @@ class S3EventNotificationContext:
    9392
    request_id: str
    9493
    event_type: str
    9594
    event_time: datetime.datetime
    95+
    account_id: str
    9696
    region: str
    9797
    bucket_name: BucketName
    9898
    key_name: ObjectKey
    @@ -156,6 +156,7 @@ def from_request_context(
    156156
    request_id=request_context.request_id,
    157157
    event_type=EVENT_OPERATION_MAP.get(request_context.operation.wire_name, ""),
    158158
    event_time=datetime.datetime.now(),
    159+
    account_id=request_context.account_id,
    159160
    region=request_context.region,
    160161
    caller=request_context.account_id, # TODO: use it for `userIdentity`
    161162
    bucket_name=bucket_name,
    @@ -205,6 +206,7 @@ def from_request_context_native(
    205206
    request_id=request_context.request_id,
    206207
    event_type=EVENT_OPERATION_MAP.get(request_context.operation.wire_name, ""),
    207208
    event_time=datetime.datetime.now(),
    209+
    account_id=request_context.account_id,
    208210
    region=request_context.region,
    209211
    caller=request_context.account_id, # TODO: use it for `userIdentity`
    210212
    bucket_name=bucket_name,
    @@ -448,7 +450,9 @@ def _get_arn_value_and_name(queue_configuration: QueueConfiguration) -> Tuple[Qu
    448450

    449451
    def _verify_target(self, target_arn: str, verification_ctx: BucketVerificationContext) -> None:
    450452
    arn_data = parse_arn(target_arn)
    451-
    sqs_client = connect_to(region_name=arn_data["region"]).sqs
    453+
    sqs_client = connect_to(
    454+
    aws_access_key_id=arn_data["account"], region_name=arn_data["region"]
    455+
    ).sqs
    452456
    try:
    453457
    queue_url = sqs_client.get_queue_url(
    454458
    QueueName=arn_data["resource"], QueueOwnerAWSAccountId=arn_data["account"]
    @@ -462,7 +466,7 @@ def _verify_target(self, target_arn: str, verification_ctx: BucketVerificationCo
    462466
    )
    463467
    # send test event
    464468
    # https://docs.aws.amazon.com/AmazonS3/latest/userguide/notification-how-to-event-types-and-destinations.html#supported-notification-event-types
    465-
    sqs_client = sqs_client.request_metadata(
    469+
    sqs_client = connect_to().sqs.request_metadata(
    466470
    source_arn=s3_bucket_arn(verification_ctx.bucket_name),
    467471
    service_principal=ServicePrincipal.s3,
    468472
    )
    @@ -487,8 +491,9 @@ def notify(self, ctx: S3EventNotificationContext, config: QueueConfiguration):
    487491
    queue_arn = config["QueueArn"]
    488492

    489493
    parsed_arn = parse_arn(queue_arn)
    490-
    regi 10000 on = parsed_arn["region"]
    491-
    sqs_client = connect_to(region_name=region).sqs.request_metadata(
    494+
    sqs_client = connect_to(
    495+
    aws_access_key_id=parsed_arn["account"], region_name=parsed_arn["region"]
    496+
    ).sqs.request_metadata(
    492497
    source_arn=s3_bucket_arn(ctx.bucket_name), service_principal=ServicePrincipal.s3
    493498
    )
    494499
    try:
    @@ -521,7 +526,9 @@ def _get_arn_value_and_name(topic_configuration: TopicConfiguration) -> [TopicAr
    521526

    522527
    def _verify_target(self, target_arn: str, verification_ctx: BucketVerificationContext) -> None:
    523528
    arn_data = parse_arn(target_arn)
    524-
    sns_client = connect_to(region_name=arn_data["region"]).sns
    529+
    sns_client = connect_to(
    530+
    aws_access_key_id=arn_data["account"], region_name=arn_data["region"]
    531+
    ).sns
    525532
    try:
    526533
    sns_client.get_topic_attributes(TopicArn=target_arn)
    527534
    except ClientError:
    @@ -531,7 +538,7 @@ def _verify_target(self, target_arn: str, verification_ctx: BucketVerificationCo
    531538
    value="The destination topic does not exist",
    532539
    )
    533540

    534-
    sns_client = sns_client.request_metadata(
    541+
    sns_client = connect_to().sns.request_metadata(
    535542
    source_arn=s3_bucket_arn(verification_ctx.bucket_name),
    536543
    service_principal=ServicePrincipal.s3,
    537544
    )
    @@ -566,8 +573,10 @@ def notify(self, ctx: S3EventNotificationContext, config: TopicConfiguration):
    566573
    message = json.dumps(event_payload)
    567574
    topic_arn = config["TopicArn"]
    568575

    569-
    region_name = arns.extract_region_from_arn(topic_arn)
    570-
    sns_client = connect_to(region_name=region_name).sns.request_metadata(
    576+
    arn_data = parse_arn(topic_arn)
    577+
    sns_client = connect_to(
    578+
    aws_access_key_id=arn_data["account"], region_name=arn_data["region"]
    579+
    ).sns.request_metadata(
    571580
    source_arn=s3_bucket_arn(ctx.bucket_name), service_principal=ServicePrincipal.s3
    572581
    )
    573582
    try:
    @@ -595,7 +604,9 @@ def _get_arn_value_and_name(
    595604

    596605
    def _verify_target(self, target_arn: str, verification_ctx: BucketVerificationContext) -> None:
    597606
    arn_data = parse_arn(arn=target_arn)
    598-
    lambda_client = connect_to(region_name=arn_data["region"]).lambda_
    607+
    lambda_client = connect_to(
    608+
    aws_access_key_id=arn_data["account"], region_name=arn_data["region"]
    609+
    ).lambda_
    599610
    try:
    600611
    lambda_client.get_function(FunctionName=target_arn)
    601612
    except ClientError:
    @@ -604,7 +615,7 @@ def _verify_target(self, target_arn: str, verification_ctx: BucketVerificationCo
    604615
    name=target_arn,
    605616
    value="The destination Lambda does not exist",
    606617
    )
    607-
    lambda_client = lambda_client.request_metadata(
    618+
    lambda_client = connect_to().lambda_.request_metadata(
    608619
    source_arn=s3_bucket_arn(verification_ctx.bucket_name),
    609620
    service_principal=ServicePrincipal.s3,
    610621
    )
    @@ -622,8 +633,11 @@ def notify(self, ctx: S3EventNotificationContext, config: LambdaFunctionConfigur
    622633
    payload = json.dumps(event_payload)
    623634
    lambda_arn = config["LambdaFunctionArn"]
    624635

    625-
    region_name = arns.extract_region_from_arn(lambda_arn)
    626-
    lambda_client = connect_to(region_name=region_name).lambda_.request_metadata(
    636+
    arn_data = parse_arn(lambda_arn)
    637+
    638+
    lambda_client = connect_to(
    639+
    aws_access_key_id=arn_data["account"], region_name=arn_data["region"]
    640+
    ).lambda_.request_metadata(
    627641
    source_arn=s3_bucket_arn(ctx.bucket_name), service_principal=ServicePrincipal.s3
    628642
    )
    629643
    lambda_function_config = arns.lambda_function_name(lambda_arn)
    @@ -740,10 +754,9 @@ def _verify_target(self, target_arn: str, verification_ctx: BucketVerificationCo
    740754
    return
    741755

    742756
    def notify(self, ctx: S3EventNotificationContext, config: EventBridgeConfiguration):
    743-
    region = ctx.bucket_location or DEFAULT_REGION
    744757
    # does not require permissions
    745758
    # https://docs.aws.amazon.com/AmazonS3/latest/userguide/ev-permissions.html
    746-
    events_client = connect_to(region_name=region).events
    759+
    events_client = connect_to(aws_access_key_id=ctx.account_id, region_name=ctx.region).events
    747760
    entry = self._get_event_payload(ctx)
    748761
    try:
    749762
    events_client.put_events(Entries=[entry])

    localstack/services/sns/provider.py

    Lines changed: 9 additions & 9 deletions
    Original file line numberDiff line numberDiff line change
    @@ -8,7 +8,6 @@
    88
    from moto.sns.models import MAXIMUM_MESSAGE_LENGTH, SNSBackend, Topic
    99
    from moto.sns.utils import is_e164
    1010

    11-
    from localstack.aws.accounts import get_aws_account_id
    1211
    from localstack.aws.api import CommonServiceException, RequestContext
    1312
    from localstack.aws.api.sns import (
    1413
    AmazonResourceName,
    @@ -48,6 +47,7 @@
    4847
    topicARN,
    4948
    topicName,
    5049
    )
    50+
    from localstack.constants import AWS_REGION_US_EAST_1, DEFAULT_AWS_ACCOUNT_ID
    5151
    from localstack.http import Request, Response, Router, route
    5252
    from localstack.services.edge import ROUTER
    5353
    from localstack.services.moto import call_moto
    @@ -937,8 +937,8 @@ class SNSServicePlatformEndpointMessagesApiResource:
    937937

    938938
    @route(sns_constants.PLATFORM_ENDPOINT_MSGS_ENDPOINT, methods=["GET"])
    939939
    def on_get(self, request: Request):
    940-
    account_id = request.args.get("accountId", get_aws_account_id())
    941-
    region = request.args.get("region", "us-east-1")
    940+
    account_id = request.args.get("accountId", DEFAULT_AWS_ACCOUNT_ID)
    941+
    region = request.args.get("region", AWS_REGION_US_EAST_1)
    942942
    filter_endpoint_arn = request.args.get("endpointArn")
    943943
    store: SnsStore = sns_stores[account_id][region]
    944944
    if filter_endpoint_arn:
    @@ -960,8 +960,8 @@ def on_get(self, request: Request):
    960960

    961961
    @route(sns_constants.PLATFORM_ENDPOINT_MSGS_ENDPOINT, methods=["DELETE"])
    962962
    def on_delete(self, request: Request) -> Response:
    963-
    account_id = request.args.get("accountId", get_aws_account_id())
    964-
    region = request.args.get("region", "us-east-1")
    963+
    account_id = request.args.get("accountId", DEFAULT_AWS_ACCOUNT_ID)
    964+
    region = request.args.get("region", AWS_REGION_US_EAST_1)
    965965
    filter_endpoint_arn = request.args.get("endpointArn")
    966966
    store: SnsStore = sns_stores[account_id][region]
    967967
    if filter_endpoint_arn:
    @@ -996,8 +996,8 @@ class SNSServiceSMSMessagesApiResource:
    996996

    997997
    @route(sns_constants.SMS_MSGS_ENDPOINT, methods=["GET"])
    998998
    def on_get(self, request: Request):
    999-
    account_id = request.args.get("accountId", get_aws_account_id())
    1000-
    region = request.args.get("region", "us-east-1")
    999+
    account_id = request.args.get("accountId", DEFAULT_AWS_ACCOUNT_ID)
    1000+
    region = request.args.get("region", AWS_REGION_US_EAST_1)
    10011001
    filter_phone_number = request.args.get("phoneNumber")
    10021002
    store: SnsStore = sns_stores[account_id][region]
    10031003
    if filter_phone_number:
    @@ -1022,8 +1022,8 @@ def on_get(self, request: Request):
    10221022

    10231023
    @route(sns_constants.SMS_MSGS_ENDPOINT, methods=["DELETE"])
    10241024
    def on_delete(self, request: Request) -> Response:
    1025-
    account_id = request.args.get("accountId", get_aws_account_id())
    1026-
    region = request.args.get("region", "us-east-1")
    1025+
    account_id = request.args.get("accountId", DEFAULT_AWS_ACCOUNT_ID)
    1026+
    region = request.args.get("region", AWS_REGION_US_EAST_1)
    10271027
    filter_phone_number = request.args.get("phoneNumber")
    10281028
    store: SnsStore = sns_stores[account_id][region]
    10291029
    if filter_phone_number:

    localstack/services/sns/publisher.py

    Lines changed: 15 additions & 6 deletions
    Original file line numberDiff line numberDiff line change
    @@ -333,10 +333,13 @@ def _publish(self, context: SnsBatchPublishContext, subscriber: SnsSubscription)
    333333

    334334
    try:
    335335
    queue_url = sqs_queue_url_for_arn(subscriber["Endpoint"])
    336+
    337+
    account_id = extract_account_id_from_arn(subscriber["Endpoint"])
    336338
    region = extract_region_from_arn(subscriber["Endpoint"])
    337-
    sqs_client = connect_to(region_name=region).sqs.request_metadata(
    338-
    source_arn=subscriber["TopicArn"], service_principal="sns"
    339-
    )
    339+
    340+
    sqs_client = connect_to(
    341+
    aws_access_key_id=account_id, region_name=region
    342+
    ).sqs.request_metadata(source_arn=subscriber["TopicArn"], service_principal="sns")
    340343
    response = sqs_client.send_message_batch(QueueUrl=queue_url, Entries=entries)
    341344

    342345
    for message_ctx in context.messages:
    @@ -453,8 +456,9 @@ class EmailJsonTopicPublisher(TopicPublisher):
    453456
    """
    454457

    455458
    def _publish(self, context: SnsPublishContext, subscriber: SnsSubscription):
    459+
    account_id = extract_account_id_from_arn(subscriber["Endpoint"])
    456460
    region = extract_region_from_arn(subscriber["Endpoint"])
    457-
    ses_client = connect_to(region_name=region).ses
    461+
    ses_client = connect_to(aws_access_key_id=account_id, region_name=region).ses
    458462
    if endpoint := subscriber.get("Endpoint"):
    459463
    ses_client.verify_email_address(EmailAddress=endpoint)
    460464
    ses_client.verify_email_address(EmailAddress="admin@localstack.com")
    @@ -598,7 +602,8 @@ def _publish(self, context: SnsPublishContext, subscriber: SnsSubscription):
    598602
    role_arn=role_arn, service_principal=ServicePrincipal.sns, region_name=region
    599603
    )
    600604
    else:
    601-
    factory = connect_to(region_name=region)
    605+
    account_id = extract_account_id_from_arn(subscriber["Endpoint"])
    606+
    factory = connect_to(aws_access_key_id=account_id, region_name=region)
    602607
    firehose_client = factory.firehose.request_metadata(
    603608
    source_arn=subscriber["TopicArn"], service_principal=ServicePrincipal.sns
    604609
    )
    @@ -722,7 +727,11 @@ def get_attributes_for_application_endpoint(endpoint_arn: str) -> Tuple[Dict, Di
    722727
    :param endpoint_arn:
    723728
    :return:
    724729
    """
    725-
    sns_client = connect_to().sns
    730+
    account_id = extract_account_id_from_arn(endpoint_arn)
    731+
    region_name = extract_region_from_arn(endpoint_arn)
    732+
    733+
    sns_client = connect_to(aws_access_key_id=account_id, region_name=region_name).sns
    734+
    726735
    # TODO: we should access this from the moto store directly
    727736
    endpoint_attributes = sns_client.get_endpoint_attributes(EndpointArn=endpoint_arn)
    728737

    localstack/services/sqs/provider.py

    Lines changed: 3 additions & 7 deletions
    Original file line numberDiff line numberDiff line change
    @@ -12,7 +12,6 @@
    1212
    from moto.sqs.models import Message as MotoMessage
    1313

    1414
    from localstack import config
    15-
    from localstack.aws.accounts import get_aws_account_id
    1615
    from localstack.aws.api import CommonServiceException, RequestContext, ServiceException
    1716
    from localstack.aws.api.sqs import (
    1817
    ActionNameList,
    @@ -83,7 +82,6 @@
    8382
    is_message_deduplication_id_required,
    8483
    parse_queue_url,
    8584
    )
    86-
    from localstack.utils.aws import aws_stack
    8785
    from localstack.utils.aws.arns import parse_arn
    8886
    from localstack.utils.aws.request_context import extract_region_from_headers
    8987
    from localstack.utils.cloudwatch.cloudwatch_util import publish_sqs_metric
    @@ -484,9 +482,7 @@ def _get_and_serialize_messages(
    484482
    self, request: Request, region: str, account_id: str, queue_name: str
    485483
    ) -> ReceiveMessageResult:
    486484
    try:
    487-
    store = self.stores[account_id or get_aws_account_id()][
    488-
    region or aws_stack.get_region()
    489-
    ]
    485+
    store = SqsProvider.get_store(account_id, region)
    490486
    queue = store.queues[queue_name]
    491487
    except KeyError:
    492488
    LOG.info(
    @@ -560,8 +556,8 @@ def __init__(self) -> None:
    560556
    self._init_cloudwatch_metrics_reporting()
    561557

    562558
    @staticmethod
    563-
    def get_store(account_id: str = None, region: str = None) -> SqsStore:
    564-
    return sqs_stores[account_id or get_aws_account_id()][region or aws_stack.get_region()]
    559+
    def get_store(account_id: str, region: str) -> SqsStore:
    560+
    return sqs_stores[account_id][region]
    565561

    566562
    def on_before_start(self):
    567563
    self._router_rules = ROUTER.add(SqsDeveloperEndpoints())

    0 commit comments

    Comments
     (0)
    0