From 79ec46e61e52940f2be5d2377b4933a77e5c2b61 Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Tue, 29 Aug 2023 18:28:25 +0530 Subject: [PATCH 01/14] [DEBUG] Use non default account and region TO BE REVERTED BEFORE MERGE --- localstack/constants.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/localstack/constants.py b/localstack/constants.py index 8abca4bf08009..547da354d4cf7 100644 --- a/localstack/constants.py +++ b/localstack/constants.py @@ -153,10 +153,10 @@ # Credentials used in the test suite # These can be overridden if the tests are being run against AWS # If a structured access key ID is used, it must correspond to the account ID -TEST_AWS_ACCOUNT_ID = os.getenv("TEST_AWS_ACCOUNT_ID") or DEFAULT_AWS_ACCOUNT_ID -TEST_AWS_ACCESS_KEY_ID = os.getenv("TEST_AWS_ACCESS_KEY_ID") or "test" +TEST_AWS_ACCOUNT_ID = os.getenv("TEST_AWS_ACCOUNT_ID") or "000000000001" +TEST_AWS_ACCESS_KEY_ID = os.getenv("TEST_AWS_ACCESS_KEY_ID") or "000000000001" TEST_AWS_SECRET_ACCESS_KEY = os.getenv("TEST_AWS_SECRET_ACCESS_KEY") or "test" -TEST_AWS_REGION_NAME = os.getenv("TEST_AWS_REGION") or "us-east-1" +TEST_AWS_REGION_NAME = os.getenv("TEST_AWS_REGION") or "us-west-1" # Additional credentials used in the test suite (when running cross-account tests) SECONDARY_TEST_AWS_ACCOUNT_ID = os.getenv("SECONDARY_TEST_AWS_ACCOUNT_ID") or "000000000002" From 5cdc0bf9e3fba94dd99cc904383e68bbca8a61d0 Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Tue, 12 Sep 2023 16:54:11 +0530 Subject: [PATCH 02/14] Add account and region context to executions --- .../services/stepfunctions/backend/execution.py | 12 +++++++++++- localstack/services/stepfunctions/provider_v2.py | 2 ++ .../services/stepfunctions/stepfunctions_utils.py | 7 ++++++- 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/localstack/services/stepfunctions/backend/execution.py b/localstack/services/stepfunctions/backend/execution.py index 338915d3f60bc..6117fd4a2989e 100644 --- a/localstack/services/stepfunctions/backend/execution.py +++ b/localstack/services/stepfunctions/backend/execution.py @@ -80,6 +80,10 @@ def terminated(self) -> None: name: Final[str] role_arn: Final[Arn] exec_arn: Final[Arn] + + account_id: str + region_name: str + state_machine: Final[StateMachineInstance] start_date: Final[Timestamp] input_data: Final[Optional[dict]] @@ -102,6 +106,8 @@ def __init__( name: str, role_arn: Arn, exec_arn: Arn, + account_id: str, + region_name: str, state_machine: StateMachineInstance, start_date: Timestamp, input_data: Optional[dict] = None, @@ -110,6 +116,8 @@ def __init__( self.name = name self.role_arn = role_arn self.exec_arn = exec_arn + self.account_id = account_id + self.region_name = region_name self.state_machine = state_machine self.start_date = start_date self.input_data = input_data @@ -122,7 +130,9 @@ def __init__( self.exec_worker = None self.error = None self.cause = None - self._events_client = connect_to().events + self._events_client = connect_to( + aws_access_key_id=self.account_id, region_name=self.region_name + ).events def to_start_output(self) -> StartExecutionOutput: return StartExecutionOutput(executionArn=self.exec_arn, startDate=self.start_date) diff --git a/localstack/services/stepfunctions/provider_v2.py b/localstack/services/stepfunctions/provider_v2.py index 0328f31b02bf8..7936812080ff0 100644 --- a/localstack/services/stepfunctions/provider_v2.py +++ b/localstack/services/stepfunctions/provider_v2.py @@ -351,6 +351,8 @@ def start_execution( name=exec_name, role_arn=state_machine_clone.role_arn, exec_arn=exec_arn, + account_id=context.account_id, + region_name=context.region, state_machine=state_machine_clone, start_date=datetime.datetime.now(), input_data=input_data, diff --git a/localstack/services/stepfunctions/stepfunctions_utils.py b/localstack/services/stepfunctions/stepfunctions_utils.py index 04af050222842..beca5e20b3192 100644 --- a/localstack/services/stepfunctions/stepfunctions_utils.py +++ b/localstack/services/stepfunctions/stepfunctions_utils.py @@ -2,6 +2,7 @@ from typing import Dict from localstack.aws.connect import connect_to +from localstack.utils.aws.arns import parse_arn from localstack.utils.common import retry LOG = logging.getLogger(__name__) @@ -10,7 +11,11 @@ def await_sfn_execution_result(execution_arn: str, timeout_secs: int = 60) -> Dict: """Wait until the given SFN execution ARN is no longer in RUNNING status, then return execution result.""" - client = connect_to().stepfunctions + arn_data = parse_arn(execution_arn) + + client = connect_to( + aws_access_key_id=arn_data["account"], region_name=arn_data["region"] + ).stepfunctions def _get_result(): result = client.describe_execution(executionArn=execution_arn) From cfc974335d26487edd0119a10da2c22964eaad0f Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Tue, 12 Sep 2023 16:59:25 +0530 Subject: [PATCH 03/14] Use execution context for internal DDB client --- .../state_task/service/state_task_service_dynamodb.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_dynamodb.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_dynamodb.py index 48d8273b938b9..17a9a299b656a 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_dynamodb.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_dynamodb.py @@ -124,8 +124,12 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent: def _eval_service_task(self, env: Environment, parameters: dict) -> None: api_action = camel_to_snake_case(self.resource.api_action) - - dynamodb_client = connect_to(config=Config(parameter_validation=False)).dynamodb + execution = env.context_object_manager.context_object["Execution"] + dynamodb_client = connect_to( + aws_access_key_id=execution.account_id, + region_name=execution.region_name, + config=Config(parameter_validation=False), + ).dynamodb response = getattr(dynamodb_client, api_action)(**parameters) response.pop("ResponseMetadata", None) env.stack.append(response) From 2682b409404448cb01b67d911f3932e1fa46008c Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Tue, 12 Sep 2023 17:43:49 +0530 Subject: [PATCH 04/14] Build ARNs with proper account ID and region name --- .../stepfunctions/v2/test_stepfunctions_v2.py | 58 +++++++++++++------ 1 file changed, 39 insertions(+), 19 deletions(-) diff --git a/tests/aws/services/stepfunctions/v2/test_stepfunctions_v2.py b/tests/aws/services/stepfunctions/v2/test_stepfunctions_v2.py index 80b2db7f0250b..fe8af1d765dc8 100644 --- a/tests/aws/services/stepfunctions/v2/test_stepfunctions_v2.py +++ b/tests/aws/services/stepfunctions/v2/test_stepfunctions_v2.py @@ -7,6 +7,8 @@ from localstack.constants import ( SECONDARY_TEST_AWS_ACCESS_KEY_ID, SECONDARY_TEST_AWS_SECRET_ACCESS_KEY, + TEST_AWS_ACCOUNT_ID, + TEST_AWS_REGION_NAME, ) from localstack.services.events.provider import TEST_EVENTS_CACHE from localstack.testing.pytest import markers @@ -281,10 +283,12 @@ class TestStateMachine: @markers.aws.unknown def test_create_choice_state_machine(self, aws_client): state_machines_before = aws_client.stepfunctions.list_state_machines()["stateMachines"] - role_arn = arns.role_arn("sfn_role") + role_arn = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) definition = clone(STATE_MACHINE_CHOICE) - lambda_arn_4 = arns.lambda_function_arn(TEST_LAMBDA_NAME_4) + lambda_arn_4 = arns.lambda_function_arn( + TEST_LAMBDA_NAME_4, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + ) definition["States"]["Add"]["Resource"] = lambda_arn_4 definition = json.dumps(definition) sm_name = f"choice-{short_uid()}" @@ -323,9 +327,11 @@ def test_create_run_map_state_machine(self, aws_client): test_output = [{"Hello": name} for name in names] state_machines_before = aws_client.stepfunctions.list_state_machines()["stateMachines"] - role_arn = arns.role_arn("sfn_role") + role_arn = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) definition = clone(STATE_MACHINE_MAP) - lambda_arn_3 = arns.lambda_function_arn(TEST_LAMBDA_NAME_3) + lambda_arn_3 = arns.lambda_function_arn( + TEST_LAMBDA_NAME_3, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + ) definition["States"]["ExampleMapState"]["ItemProcessor"]["States"]["CallLambda"][ "Resource" ] = lambda_arn_3 @@ -361,10 +367,14 @@ def test_create_run_state_machine(self, aws_client): state_machines_before = aws_client.stepfunctions.list_state_machines()["stateMachines"] # create state machine - role_arn = arns.role_arn("sfn_role") + role_arn = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) definition = clone(STATE_MACHINE_BASIC) - lambda_arn_1 = arns.lambda_function_arn(TEST_LAMBDA_NAME_1) - lambda_arn_2 = arns.lambda_function_arn(TEST_LAMBDA_NAME_2) + lambda_arn_1 = arns.lambda_function_arn( + TEST_LAMBDA_NAME_1, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + ) + lambda_arn_2 = arns.lambda_function_arn( + TEST_LAMBDA_NAME_2, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + ) definition["States"]["step1"]["Resource"] = lambda_arn_1 definition["States"]["step2"]["Resource"] = lambda_arn_2 definition = json.dumps(definition) @@ -397,10 +407,14 @@ def test_try_catch_state_machine(self, aws_client): state_machines_before = aws_client.stepfunctions.list_state_machines()["stateMachines"] # create state machine - role_arn = arns.role_arn("sfn_role") + role_arn = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) definition = clone(STATE_MACHINE_CATCH) - lambda_arn_1 = arns.lambda_function_arn(TEST_LAMBDA_NAME_1) - lambda_arn_2 = arns.lambda_function_arn(TEST_LAMBDA_NAME_2) + lambda_arn_1 = arns.lambda_function_arn( + TEST_LAMBDA_NAME_1, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + ) + lambda_arn_2 = arns.lambda_function_arn( + TEST_LAMBDA_NAME_2, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + ) definition["States"]["Start"]["Parameters"]["FunctionName"] = lambda_arn_1 definition["States"]["ErrorHandler"]["Resource"] = lambda_arn_2 definition["States"]["Final"]["Resource"] = lambda_arn_2 @@ -431,10 +445,14 @@ def test_intrinsic_functions(self, aws_client): state_machines_before = aws_client.stepfunctions.list_state_machines()["stateMachines"] # create state machine - role_arn = arns.role_arn("sfn_role") + role_arn = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) definition = clone(STATE_MACHINE_INTRINSIC_FUNCS) - lambda_arn_1 = arns.lambda_function_arn(TEST_LAMBDA_NAME_5) - lambda_arn_2 = arns.lambda_function_arn(TEST_LAMBDA_NAME_5) + lambda_arn_1 = arns.lambda_function_arn( + TEST_LAMBDA_NAME_5, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + ) + lambda_arn_2 = arns.lambda_function_arn( + TEST_LAMBDA_NAME_5, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + ) if isinstance(definition["States"]["state1"].get("Parameters"), dict): definition["States"]["state1"]["Parameters"]["lambda_params"][ "FunctionName" @@ -480,7 +498,7 @@ def test_events_state_machine(self, aws_client): definition["States"]["step1"]["Parameters"]["Entries"][0]["EventBusName"] = bus_name definition = json.dumps(definition) sm_name = f"events-{short_uid()}" - role_arn = arns.role_arn("sfn_role") + role_arn = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) aws_client.stepfunctions.create_state_machine( name=sm_name, definition=definition, roleArn=role_arn ) @@ -517,9 +535,11 @@ def test_create_state_machines_in_parallel(self, cleanups, aws_client): CreateStateMachine operation: Invalid State Machine Definition: ''DUPLICATE_STATE_NAME: Duplicate State name: MissingValue at /States/MissingValue', 'DUPLICATE_STATE_NAME: Duplicate State name: Add at /States/Add'' """ - role_arn = arns.role_arn("sfn_role") + role_arn = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) definition = clone(STATE_MACHINE_CHOICE) - lambda_arn_4 = arns.lambda_function_arn(TEST_LAMBDA_NAME_4) + lambda_arn_4 = arns.lambda_function_arn( + TEST_LAMBDA_NAME_4, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + ) definition["States"]["Add"]["Resource"] = lambda_arn_4 definition = json.dumps(definition) results = [] @@ -608,7 +628,7 @@ def test_multiregion_nested(aws_client_factory, region_name, statemachine_defini ) # create state machine child_machine_name = f"sf-child-{short_uid()}" - role = arns.role_arn("sfn_role") + role = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) child_machine_result = client1.create_state_machine( name=child_machine_name, definition=json.dumps(TEST_STATE_MACHINE), roleArn=role ) @@ -616,7 +636,7 @@ def test_multiregion_nested(aws_client_factory, region_name, statemachine_defini # create parent state machine name = f"sf-parent-{short_uid()}" - role = arns.role_arn("sfn_role") + role = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) result = client1.create_state_machine( name=name, definition=json.dumps(statemachine_definition).replace( @@ -778,7 +798,7 @@ def test_run_aws_sdk_secrets_manager(aws_client): state_machines_before = aws_client.stepfunctions.list_state_machines()["stateMachines"] # create state machine - role_arn = arns.role_arn("sfn_role") + role_arn = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) definition = { "StartAt": "StateCreateSecret", "States": { From 14eb9a2ff4ea428c959b9f713b3ce6f6cc1d0d65 Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Tue, 12 Sep 2023 19:55:17 +0530 Subject: [PATCH 05/14] Implement account ID namespacing for legacy stepfunctions provider --- localstack/config.py | 2 - localstack/services/stepfunctions/provider.py | 34 +-- .../stepfunctions/stepfunctions_starter.py | 250 ++++++++++-------- .../legacy/test_stepfunctions_legacy.py | 60 +++-- .../stepfunctions/v2/test_stepfunctions_v2.py | 5 + 5 files changed, 200 insertions(+), 151 deletions(-) diff --git a/localstack/config.py b/localstack/config.py index 001ad6409baef..46edb7a4f217e 100644 --- a/localstack/config.py +++ b/localstack/config.py @@ -941,8 +941,6 @@ def legacy_fallback(envar_name: str, default: T) -> T: # DEV: sbx_user1051 (default when not provided) Alternative system user or empty string to skip dropping privileges. LAMBDA_INIT_USER = os.environ.get("LAMBDA_INIT_USER") -# Adding Stepfunctions default port -LOCAL_PORT_STEPFUNCTIONS = int(os.environ.get("LOCAL_PORT_STEPFUNCTIONS") or 8083) # Stepfunctions lambda endpoint override STEPFUNCTIONS_LAMBDA_ENDPOINT = os.environ.get("STEPFUNCTIONS_LAMBDA_ENDPOINT", "").strip() diff --git a/localstack/services/stepfunctions/provider.py b/localstack/services/stepfunctions/provider.py index 8f7ba23581982..cd225a7d19b81 100644 --- a/localstack/services/stepfunctions/provider.py +++ b/localstack/services/stepfunctions/provider.py @@ -3,6 +3,7 @@ import threading from localstack import config +from localstack.aws.accounts import get_aws_account_id from localstack.aws.api import RequestContext, handler from localstack.aws.api.stepfunctions import ( CreateStateMachineInput, @@ -16,12 +17,9 @@ from localstack.aws.forwarder import get_request_forwarder_http from localstack.constants import LOCALHOST from localstack.services.plugins import ServiceLifecycleHook -from localstack.services.stepfunctions.stepfunctions_starter import ( - start_stepfunctions, - stop_stepfunctions, - wait_for_stepfunctions, -) +from localstack.services.stepfunctions.stepfunctions_starter import StepFunctionsServerManager from localstack.state import AssetDirectory, StateVisitor +from localstack.utils.aws import aws_stack # lock to avoid concurrency issues when creating state machines in parallel (required for StepFunctions-Local) CREATION_LOCK = threading.RLock() @@ -30,33 +28,29 @@ class StepFunctionsProvider(StepfunctionsApi, ServiceLifecycleHook): + server_manager = StepFunctionsServerManager() + def __init__(self): self.forward_request = get_request_forwarder_http(self.get_forward_url) def get_forward_url(self) -> str: """Return the URL of the backend StepFunctions server to forward requests to""" - return f"http://{LOCALHOST}:{config.LOCAL_PORT_STEPFUNCTIONS}" + account_id = get_aws_account_id() + region_name = aws_stack.get_region() + server = self.server_manager.get_server_for_account_region(account_id, region_name) + return f"http://{LOCALHOST}:{server.port}" def accept_state_visitor(self, visitor: StateVisitor): visitor.visit(AssetDirectory(os.path.join(config.dirs.data, self.service))) - def on_before_start(self): - start_stepfunctions() - wait_for_stepfunctions() - - def on_before_state_reset(self): - stop_stepfunctions() - def on_before_state_load(self): - stop_stepfunctions() + self.server_manager.shutdown_all() - def on_after_state_reset(self): - start_stepfunctions() - wait_for_stepfunctions() + def on_before_state_reset(self): + self.server_manager.shutdown_all() - def on_after_state_load(self): - start_stepfunctions() - wait_for_stepfunctions() + def on_before_stop(self): + self.server_manager.shutdown_all() def create_state_machine( self, context: RequestContext, request: CreateStateMachineInput diff --git a/localstack/services/stepfunctions/stepfunctions_starter.py b/localstack/services/stepfunctions/stepfunctions_starter.py index 456f3d55d72d1..bf3ff382d8b38 100644 --- a/localstack/services/stepfunctions/stepfunctions_starter.py +++ b/localstack/services/stepfunctions/stepfunctions_starter.py @@ -1,125 +1,153 @@ import logging -import subprocess +import threading +from typing import Any, Dict from localstack import config -from localstack.aws.accounts import get_aws_account_id -from localstack.aws.connect import connect_to -from localstack.services.infra import do_run, log_startup_message from localstack.services.stepfunctions.packages import stepfunctions_local_package from localstack.utils.aws import aws_stack -from localstack.utils.common import wait_for_port_open -from localstack.utils.net import wait_for_port_closed -from localstack.utils.run import ShellCommandThread, wait_for_process_to_be_killed -from localstack.utils.sync import retry +from localstack.utils.net import get_free_tcp_port +from localstack.utils.run import ShellCommandThread +from localstack.utils.serving import Server +from localstack.utils.threads import TMP_THREADS, FuncThread LOG = logging.getLogger(__name__) # max heap size allocated for the Java process MAX_HEAP_SIZE = "256m" -# todo: will be replaced with plugin mechanism -PROCESS_THREAD: ShellCommandThread | subprocess.Popen | None = None - - -# TODO: pass env more explicitly -def get_command(backend_port): - install_dir_stepfunctions = stepfunctions_local_package.get_installed_dir() - cmd = ( - "cd %s; PORT=%s java " - "-javaagent:aspectjweaver-1.9.7.jar " - "-Dorg.aspectj.weaver.loadtime.configuration=META-INF/aop.xml " - "-Dcom.amazonaws.sdk.disableCertChecking -Xmx%s " - "-jar StepFunctionsLocal.jar --aws-account %s" - ) % ( - install_dir_stepfunctions, - backend_port, - MAX_HEAP_SIZE, - get_aws_account_id(), - ) - if config.STEPFUNCTIONS_LAMBDA_ENDPOINT.lower() != "default": - lambda_endpoint = config.STEPFUNCTIONS_LAMBDA_ENDPOINT or aws_stack.get_local_service_url( - "lambda" + +class StepFunctionsServer(Server): + def __init__( + self, port: int, account_id: str, region_name: str, host: str = "localhost" + ) -> None: + self.account_id = account_id + self.region_name = region_name + super().__init__(port, host) + + def do_start_thread(self) -> FuncThread: + cmd = self.generate_shell_command() + env_vars = self.generate_env_vars() + cwd = stepfunctions_local_package.get_installed_dir() + LOG.debug("Starting StepFunctions process %s with env vars %s", cmd, env_vars) + t = ShellCommandThread( + cmd, + strip_color=True, + env_vars=env_vars, + log_listener=self._log_listener, + name="stepfunctions", + cwd=cwd, ) - cmd += f" --lambda-endpoint {lambda_endpoint}" - # add service endpoint flags - services = [ - "athena", - "batch", - "dynamodb", - "ecs", - "eks", - "events", - "glue", - "sagemaker", - "sns", - "sqs", - "stepfunctions", - ] - for service in services: - flag = f"--{service}-endpoint" - if service == "stepfunctions": - flag = "--step-functions-endpoint" - elif service == "events": - flag = "--eventbridge-endpoint" - elif service in ["athena", "eks"]: - flag = f"--step-functions-{service}" - endpoint = aws_stack.get_local_service_url(service) - cmd += f" {flag} {endpoint}" - - return cmd - - -def start_stepfunctions(asynchronous: bool = True, persistence_path: str | None = None): - # TODO: introduce Server abstraction for StepFunctions process - global PROCESS_THREAD - backend_port = config.LOCAL_PORT_STEPFUNCTIONS - stepfunctions_local_package.install() - cmd = get_command(backend_port) - log_startup_message("StepFunctions") - # TODO: change ports in stepfunctions.jar, then update here - PROCESS_THREAD = do_run( - cmd, - asynchronous, - strip_color=True, - env_vars={ + TMP_THREADS.append(t) + t.start() + return t + + def generate_env_vars(self) -> Dict[str, Any]: + return { "EDGE_PORT": config.EDGE_PORT_HTTP or config.EDGE_PORT, "EDGE_PORT_HTTP": config.EDGE_PORT_HTTP or config.EDGE_PORT, - "DATA_DIR": persistence_path or config.dirs.data, - }, - ) - return PROCESS_THREAD - - -def wait_for_stepfunctions(): - retry(check_stepfunctions, sleep=0.5, retries=15) - - -def stop_stepfunctions(): - if not PROCESS_THREAD or not PROCESS_THREAD.process: - return - LOG.debug("Restarting StepFunctions process ...") - - pid = PROCESS_THREAD.process.pid - PROCESS_THREAD.stop() - wait_for_port_closed(config.LOCAL_PORT_STEPFUNCTIONS, sleep_time=0.5, retries=15) - try: - # TODO: currently failing in CI (potentially due to a defunct process) - need to investigate! - wait_for_process_to_be_killed(pid, sleep=0.3, retries=10) - except Exception as e: - LOG.warning("StepFunctions process not properly terminated: %s", e) - - -def check_stepfunctions(expect_shutdown: bool = False, print_error: bool = False) -> None: - out = None - try: - wait_for_port_open(config.LOCAL_PORT_STEPFUNCTIONS, sleep_time=2) - endpoint_url = f"http://127.0.0.1:{config.LOCAL_PORT_STEPFUNCTIONS}" - out = connect_to(endpoint_url=endpoint_url).stepfunctions.list_state_machines() - except Exception: - if print_error: - LOG.exception("StepFunctions health check failed") - - if expect_shutdown: - assert out is None - else: - assert out and isinstance(out.get("stateMachines"), list) + "DATA_DIR": config.dirs.data, + "PORT": self._port, + } + + def generate_shell_command(self) -> str: + cmd = ( + "java " + "-javaagent:aspectjweaver-1.9.7.jar " + "-Dorg.aspectj.weaver.loadtime.configuration=META-INF/aop.xml " + "-Dcom.amazonaws.sdk.disableCertChecking " + "-Xmx%s " + "-jar StepFunctionsLocal.jar " + "--aws-account %s " + "--region %s" + ) % ( + MAX_HEAP_SIZE, + self.account_id, + self.region_name, + ) + + if config.STEPFUNCTIONS_LAMBDA_ENDPOINT.lower() != "default": + lambda_endpoint = ( + config.STEPFUNCTIONS_LAMBDA_ENDPOINT or aws_stack.get_local_service_url("lambda") + ) + cmd += f" --lambda-endpoint {lambda_endpoint}" + + # add service endpoint flags + services = [ + "athena", + "batch", + "dynamodb", + "ecs", + "eks", + "events", + "glue", + "sagemaker", + "sns", + "sqs", + "stepfunctions", + ] + + for service in services: + flag = f"--{service}-endpoint" + if service == "stepfunctions": + flag = "--step-functions-endpoint" + elif service == "events": + flag = "--eventbridge-endpoint" + elif service in ["athena", "eks"]: + flag = f"--step-functions-{service}" + endpoint = aws_stack.get_local_service_url(service) + cmd += f" {flag} {endpoint}" + + return cmd + + def _log_listener(self, line, **kwargs): + LOG.debug(line.rstrip()) + + +class StepFunctionsServerManager: + default_startup_timeout = 20 + + def __init__(self): + self._lock = threading.RLock() + self._servers: dict[tuple[str, str], StepFunctionsServer] = {} + + def get_server_for_account_region( + self, account_id: str, region_name: str + ) -> StepFunctionsServer: + locator = (account_id, region_name) + + if locator in self._servers: + return self._servers[locator] + + with self._lock: + if locator in self._servers: + return self._servers[locator] + + LOG.info("Creating StepFunctions server for %s", locator) + self._servers[locator] = self._create_stepfunctions_server(account_id, region_name) + + self._servers[locator].start() + + if not self._servers[locator].wait_is_up(timeout=self.default_startup_timeout): + raise TimeoutError("Gave up waiting for StepFunctions server to start up") + + return self._servers[locator] + + def shutdown_all(self): + with self._lock: + while self._servers: + locator, server = self._servers.popitem() + LOG.info("Shutting down StepFunctions for %s", locator) + server.shutdown() + + def _create_stepfunctions_server( + self, account_id: str, region_name: str + ) -> StepFunctionsServer: + port = get_free_tcp_port() + stepfunctions_local_package.install() + + server = StepFunctionsServer( + port=port, + account_id=account_id, + region_name=region_name, + ) + return server diff --git a/tests/aws/services/stepfunctions/legacy/test_stepfunctions_legacy.py b/tests/aws/services/stepfunctions/legacy/test_stepfunctions_legacy.py index 961e7bd4962fa..dd8badd8a0a09 100644 --- a/tests/aws/services/stepfunctions/legacy/test_stepfunctions_legacy.py +++ b/tests/aws/services/stepfunctions/legacy/test_stepfunctions_legacy.py @@ -4,6 +4,7 @@ import pytest +from localstack.constants import TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME from localstack.services.events.provider import TEST_EVENTS_CACHE from localstack.services.stepfunctions.stepfunctions_utils import await_sfn_execution_result from localstack.testing.pytest import markers @@ -184,29 +185,34 @@ def setup_and_tear_down(aws_client): zip_file=zip_file, envvars={"Hello": TEST_RESULT_VALUE}, client=aws_client.lambda_, + s3_client=aws_client.s3, ) testutil.create_lambda_function( func_name=TEST_LAMBDA_NAME_2, zip_file=zip_file, envvars={"Hello": TEST_RESULT_VALUE_2}, client=aws_client.lambda_, + s3_client=aws_client.s3, ) testutil.create_lambda_function( func_name=TEST_LAMBDA_NAME_3, zip_file=zip_file, envvars={"Hello": "Replace Value"}, client=aws_client.lambda_, + s3_client=aws_client.s3, ) testutil.create_lambda_function( func_name=TEST_LAMBDA_NAME_4, zip_file=zip_file, envvars={"Hello": TEST_RESULT_VALUE_4}, client=aws_client.lambda_, + s3_client=aws_client.s3, ) testutil.create_lambda_function( func_name=TEST_LAMBDA_NAME_5, zip_file=zip_file2, client=aws_client.lambda_, + s3_client=aws_client.s3, ) active_waiter = lambda_client.get_waiter("function_active_v2") @@ -282,10 +288,12 @@ class TestStateMachine: @markers.aws.unknown def test_create_choice_state_machine(self, aws_client): state_machines_before = aws_client.stepfunctions.list_state_machines()["stateMachines"] - role_arn = arns.role_arn("sfn_role") + role_arn = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) definition = clone(STATE_MACHINE_CHOICE) - lambda_arn_4 = arns.lambda_function_arn(TEST_LAMBDA_NAME_4) + lambda_arn_4 = arns.lambda_function_arn( + TEST_LAMBDA_NAME_4, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + ) definition["States"]["Add"]["Resource"] = lambda_arn_4 definition = json.dumps(definition) sm_name = f"choice-{short_uid()}" @@ -324,9 +332,11 @@ def test_create_run_map_state_machine(self, aws_client): test_output = [{"Hello": name} for name in names] state_machines_before = aws_client.stepfunctions.list_state_machines()["stateMachines"] - role_arn = arns.role_arn("sfn_role") + role_arn = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) definition = clone(STATE_MACHINE_MAP) - lambda_arn_3 = arns.lambda_function_arn(TEST_LAMBDA_NAME_3) + lambda_arn_3 = arns.lambda_function_arn( + TEST_LAMBDA_NAME_3, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + ) definition["States"]["ExampleMapState"]["Iterator"]["States"]["CallLambda"][ "Resource" ] = lambda_arn_3 @@ -362,10 +372,14 @@ def test_create_run_state_machine(self, aws_client): state_machines_before = aws_client.stepfunctions.list_state_machines()["stateMachines"] # create state machine - role_arn = arns.role_arn("sfn_role") + role_arn = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) definition = clone(STATE_MACHINE_BASIC) - lambda_arn_1 = arns.lambda_function_arn(TEST_LAMBDA_NAME_1) - lambda_arn_2 = arns.lambda_function_arn(TEST_LAMBDA_NAME_2) + lambda_arn_1 = arns.lambda_function_arn( + TEST_LAMBDA_NAME_1, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + ) + lambda_arn_2 = arns.lambda_function_arn( + TEST_LAMBDA_NAME_2, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + ) definition["States"]["step1"]["Resource"] = lambda_arn_1 definition["States"]["step2"]["Resource"] = lambda_arn_2 definition = json.dumps(definition) @@ -401,10 +415,14 @@ def test_try_catch_state_machine(self, aws_client): state_machines_before = aws_client.stepfunctions.list_state_machines()["stateMachines"] # create state machine - role_arn = arns.role_arn("sfn_role") + role_arn = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) definition = clone(STATE_MACHINE_CATCH) - lambda_arn_1 = arns.lambda_function_arn(TEST_LAMBDA_NAME_1) - lambda_arn_2 = arns.lambda_function_arn(TEST_LAMBDA_NAME_2) + lambda_arn_1 = arns.lambda_function_arn( + TEST_LAMBDA_NAME_1, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + ) + lambda_arn_2 = arns.lambda_function_arn( + TEST_LAMBDA_NAME_2, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + ) definition["States"]["Start"]["Parameters"]["FunctionName"] = lambda_arn_1 definition["States"]["ErrorHandler"]["Resource"] = lambda_arn_2 definition["States"]["Final"]["Resource"] = lambda_arn_2 @@ -439,10 +457,14 @@ def test_intrinsic_functions(self, aws_client): state_machines_before = aws_client.stepfunctions.list_state_machines()["stateMachines"] # create state machine - role_arn = arns.role_arn("sfn_role") + role_arn = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) definition = clone(STATE_MACHINE_INTRINSIC_FUNCS) - lambda_arn_1 = arns.lambda_function_arn(TEST_LAMBDA_NAME_5) - lambda_arn_2 = arns.lambda_function_arn(TEST_LAMBDA_NAME_5) + lambda_arn_1 = arns.lambda_function_arn( + TEST_LAMBDA_NAME_5, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + ) + lambda_arn_2 = arns.lambda_function_arn( + TEST_LAMBDA_NAME_5, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + ) if isinstance(definition["States"]["state1"].get("Parameters"), dict): definition["States"]["state1"]["Parameters"]["lambda_params"][ "FunctionName" @@ -487,7 +509,7 @@ def test_events_state_machine(self, aws_client): definition["States"]["step1"]["Parameters"]["Entries"][0]["EventBusName"] = bus_name definition = json.dumps(definition) sm_name = f"events-{short_uid()}" - role_arn = arns.role_arn("sfn_role") + role_arn = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) aws_client.stepfunctions.create_state_machine( name=sm_name, definition=definition, roleArn=role_arn ) @@ -524,9 +546,11 @@ def test_create_state_machines_in_parallel(self, cleanups, aws_client): CreateStateMachine operation: Invalid State Machine Definition: ''DUPLICATE_STATE_NAME: Duplicate State name: MissingValue at /States/MissingValue', 'DUPLICATE_STATE_NAME: Duplicate State name: Add at /States/Add'' """ - role_arn = arns.role_arn("sfn_role") + role_arn = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) definition = clone(STATE_MACHINE_CHOICE) - lambda_arn_4 = arns.lambda_function_arn(TEST_LAMBDA_NAME_4) + lambda_arn_4 = arns.lambda_function_arn( + TEST_LAMBDA_NAME_4, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + ) definition["States"]["Add"]["Resource"] = lambda_arn_4 definition = json.dumps(definition) results = [] @@ -609,7 +633,7 @@ def test_multiregion_nested(aws_client_factory, region_name, statemachine_defini client1 = aws_client_factory(region_name=region_name).stepfunctions # create state machine child_machine_name = f"sf-child-{short_uid()}" - role = arns.role_arn("sfn_role") + role = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) child_machine_result = client1.create_state_machine( name=child_machine_name, definition=json.dumps(TEST_STATE_MACHINE), roleArn=role ) @@ -617,7 +641,7 @@ def test_multiregion_nested(aws_client_factory, region_name, statemachine_defini # create parent state machine name = f"sf-parent-{short_uid()}" - role = arns.role_arn("sfn_role") + role = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) result = client1.create_state_machine( name=name, definition=json.dumps(statemachine_definition).replace( diff --git a/tests/aws/services/stepfunctions/v2/test_stepfunctions_v2.py b/tests/aws/services/stepfunctions/v2/test_stepfunctions_v2.py index fe8af1d765dc8..d60cd6e118d60 100644 --- a/tests/aws/services/stepfunctions/v2/test_stepfunctions_v2.py +++ b/tests/aws/services/stepfunctions/v2/test_stepfunctions_v2.py @@ -195,29 +195,34 @@ def setup_and_tear_down(aws_client): zip_file=zip_file, envvars={"Hello": TEST_RESULT_VALUE}, client=aws_client.lambda_, + s3_client=aws_client.s3, ) testutil.create_lambda_function( func_name=TEST_LAMBDA_NAME_2, zip_file=zip_file, envvars={"Hello": TEST_RESULT_VALUE_2}, client=aws_client.lambda_, + s3_client=aws_client.s3, ) testutil.create_lambda_function( func_name=TEST_LAMBDA_NAME_3, zip_file=zip_file, envvars={"Hello": "Replace Value"}, client=aws_client.lambda_, + s3_client=aws_client.s3, ) testutil.create_lambda_function( func_name=TEST_LAMBDA_NAME_4, zip_file=zip_file, envvars={"Hello": TEST_RESULT_VALUE_4}, client=aws_client.lambda_, + s3_client=aws_client.s3, ) testutil.create_lambda_function( func_name=TEST_LAMBDA_NAME_5, zip_file=zip_file2, client=aws_client.lambda_, + s3_client=aws_client.s3, ) active_waiter = lambda_client.get_waiter("function_active_v2") From f4c8a85cbed64a6a36ec71a1d52bb366aad0121d Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Wed, 13 Sep 2023 16:44:45 +0530 Subject: [PATCH 06/14] Take account ID and region into consideration for internal requests --- .../state_task/lambda_eval_utils.py | 6 ++---- .../service/state_task_service_aws_sdk.py | 7 ++----- .../service/state_task_service_dynamodb.py | 10 ++-------- .../state_task/service/state_task_service_sfn.py | 12 +++--------- .../state_task/service/state_task_service_sns.py | 5 ++--- .../state_task/service/state_task_service_sqs.py | 5 ++--- .../asl/eval/contextobject/contex_object.py | 2 ++ .../services/stepfunctions/backend/execution.py | 2 ++ localstack/services/stepfunctions/backend/utils.py | 14 ++++++++++++++ 9 files changed, 31 insertions(+), 32 deletions(-) create mode 100644 localstack/services/stepfunctions/backend/utils.py diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/lambda_eval_utils.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/lambda_eval_utils.py index 1dd03cb568da4..6439b1783cbd3 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/lambda_eval_utils.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/lambda_eval_utils.py @@ -2,12 +2,10 @@ from json import JSONDecodeError from typing import Any, Final, Optional -from botocore.config import Config - from localstack.aws.api.lambda_ import InvocationResponse -from localstack.aws.connect import connect_externally_to from localstack.services.stepfunctions.asl.eval.environment import Environment from localstack.services.stepfunctions.asl.utils.encoding import to_json_str +from localstack.services.stepfunctions.backend.utils import get_boto_client from localstack.utils.collections import select_from_typed_dict from localstack.utils.run import to_str from localstack.utils.strings import to_bytes @@ -23,7 +21,7 @@ def __init__(self, function_error: Optional[str], payload: str): def exec_lambda_function(env: Environment, parameters: dict) -> None: - lambda_client = connect_externally_to(config=Config(parameter_validation=False)).lambda_ + lambda_client = get_boto_client(env, "lambda") invocation_resp: InvocationResponse = lambda_client.invoke(**parameters) func_error: Optional[str] = invocation_resp.get("FunctionError") diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_aws_sdk.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_aws_sdk.py index 148d8aba7bceb..9c009f5a95fcb 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_aws_sdk.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_aws_sdk.py @@ -1,8 +1,6 @@ -from botocore.config import Config from botocore.exceptions import ClientError from localstack.aws.api.stepfunctions import HistoryEventType, TaskFailedEventDetails -from localstack.aws.connect import connect_externally_to from localstack.aws.protocol.service_router import get_service_catalog from localstack.services.stepfunctions.asl.component.common.error_name.failure_event import ( FailureEvent, @@ -19,6 +17,7 @@ from localstack.services.stepfunctions.asl.component.state.state_props import StateProps from localstack.services.stepfunctions.asl.eval.environment import Environment from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails +from localstack.services.stepfunctions.backend.utils import get_boto_client from localstack.utils.common import camel_to_snake_case @@ -105,9 +104,7 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent: return super()._from_error(env=env, ex=ex) def _eval_service_task(self, env: Environment, parameters: dict) -> None: - api_client = connect_externally_to.get_client( - service_name=self._normalised_api_name, config=Config(parameter_validation=False) - ) + api_client = get_boto_client(env, self._normalised_api_name) response = getattr(api_client, self._normalised_api_action)(**parameters) or dict() if response: response.pop("ResponseMetadata", None) diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_dynamodb.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_dynamodb.py index 17a9a299b656a..a2f0309a299bd 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_dynamodb.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_dynamodb.py @@ -1,10 +1,8 @@ from typing import Final, Optional -from botocore.config import Config from botocore.exceptions import ClientError from localstack.aws.api.stepfunctions import HistoryEventType, TaskFailedEventDetails -from localstack.aws.connect import connect_to from localstack.services.stepfunctions.asl.component.common.error_name.custom_error_name import ( CustomErrorName, ) @@ -16,6 +14,7 @@ ) from localstack.services.stepfunctions.asl.eval.environment import Environment from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails +from localstack.services.stepfunctions.backend.utils import get_boto_client from localstack.utils.strings import camel_to_snake_case @@ -124,12 +123,7 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent: def _eval_service_task(self, env: Environment, parameters: dict) -> None: api_action = camel_to_snake_case(self.resource.api_action) - execution = env.context_object_manager.context_object["Execution"] - dynamodb_client = connect_to( - aws_access_key_id=execution.account_id, - region_name=execution.region_name, - config=Config(parameter_validation=False), - ).dynamodb + dynamodb_client = get_boto_client(env, "dynamodb") response = getattr(dynamodb_client, api_action)(**parameters) response.pop("ResponseMetadata", None) env.stack.append(response) diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sfn.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sfn.py index bc93a9f31ed71..a139a1ac7feaf 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sfn.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sfn.py @@ -1,7 +1,6 @@ import json from typing import Any, Final, Optional -from botocore.config import Config from botocore.exceptions import ClientError from localstack.aws.api.stepfunctions import ( @@ -10,7 +9,6 @@ HistoryEventType, TaskFailedEventDetails, ) -from localstack.aws.connect import connect_externally_to from localstack.services.stepfunctions.asl.component.common.error_name.custom_error_name import ( CustomErrorName, ) @@ -30,6 +28,7 @@ from localstack.services.stepfunctions.asl.eval.environment import Environment from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails from localstack.services.stepfunctions.asl.utils.encoding import to_json_str +from localstack.services.stepfunctions.backend.utils import get_boto_client from localstack.utils.collections import select_from_typed_dict from localstack.utils.strings import camel_to_snake_case @@ -100,10 +99,6 @@ def _apply_normalisation(lookup_keys, dictionary): lower_to_normalise_key = _build_lower_to_key_dict(keys) _apply_normalisation(lower_to_normalise_key, response) - @staticmethod - def _get_sfn_client(): - return connect_externally_to(config=Config(parameter_validation=False)).stepfunctions - def _from_error(self, env: Environment, ex: Exception) -> FailureEvent: if isinstance(ex, ClientError): error_code = ex.response["Error"]["Code"] @@ -166,15 +161,14 @@ def _replace_with_json_if_str(key: str) -> None: def _eval_service_task(self, env: Environment, parameters: dict) -> None: api_action = camel_to_snake_case(self.resource.api_action) - sfn_client = self._get_sfn_client() + sfn_client = get_boto_client(env, "stepfunctions") response = getattr(sfn_client, api_action)(**parameters) response.pop("ResponseMetadata", None) self._normalise_botocore_response(self.resource.api_action, response) env.stack.append(response) def _sync_to_start_machine(self, env: Environment, sync2_response: bool) -> None: - sfn_client = self._get_sfn_client() - + sfn_client = get_boto_client(env, "stepfunctions") submission_output: dict = env.stack.pop() execution_arn: str = submission_output["ExecutionArn"] diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sns.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sns.py index 3f5b79741fe08..3d8d59f0c4aa5 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sns.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sns.py @@ -1,10 +1,8 @@ from typing import Final, Optional -from botocore.config import Config from botocore.exceptions import ClientError from localstack.aws.api.stepfunctions import HistoryEventType, TaskFailedEventDetails -from localstack.aws.connect import connect_externally_to from localstack.services.stepfunctions.asl.component.common.error_name.custom_error_name import ( CustomErrorName, ) @@ -16,6 +14,7 @@ ) from localstack.services.stepfunctions.asl.eval.environment import Environment from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails +from localstack.services.stepfunctions.backend.utils import get_boto_client from localstack.utils.strings import camel_to_snake_case @@ -72,7 +71,7 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent: def _eval_service_task(self, env: Environment, parameters: dict) -> None: api_action = camel_to_snake_case(self.resource.api_action) - sns_client = connect_externally_to(config=Config(parameter_validation=False)).sns + sns_client = get_boto_client(env, "sns") response = getattr(sns_client, api_action)(**parameters) response.pop("ResponseMetadata", None) env.stack.append(response) diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sqs.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sqs.py index 9c207afc24b90..ae9dd10860c39 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sqs.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sqs.py @@ -1,10 +1,8 @@ from typing import Final, Optional -from botocore.config import Config from botocore.exceptions import ClientError from localstack.aws.api.stepfunctions import HistoryEventType, TaskFailedEventDetails -from localstack.aws.connect import connect_externally_to from localstack.services.stepfunctions.asl.component.common.error_name.custom_error_name import ( CustomErrorName, ) @@ -17,6 +15,7 @@ from localstack.services.stepfunctions.asl.eval.environment import Environment from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails from localstack.services.stepfunctions.asl.utils.encoding import to_json_str +from localstack.services.stepfunctions.backend.utils import get_boto_client from localstack.utils.strings import camel_to_snake_case @@ -65,7 +64,7 @@ def _eval_service_task(self, env: Environment, parameters: dict) -> None: parameters["MessageBody"] = to_json_str(message_body) api_action = camel_to_snake_case(self.resource.api_action) - sqs_client = connect_externally_to(config=Config(parameter_validation=False)).sqs + sqs_client = get_boto_client(env, "sqs") response = getattr(sqs_client, api_action)(**parameters) response.pop("ResponseMetadata", None) env.stack.append(response) diff --git a/localstack/services/stepfunctions/asl/eval/contextobject/contex_object.py b/localstack/services/stepfunctions/asl/eval/contextobject/contex_object.py index 33fbadc30cf11..ec44f93b73573 100644 --- a/localstack/services/stepfunctions/asl/eval/contextobject/contex_object.py +++ b/localstack/services/stepfunctions/asl/eval/contextobject/contex_object.py @@ -9,6 +9,8 @@ class Execution(TypedDict): Name: str RoleArn: str StartTime: str # Format: ISO 8601. + AccountId: str + RegionName: str class State(TypedDict): diff --git a/localstack/services/stepfunctions/backend/execution.py b/localstack/services/stepfunctions/backend/execution.py index 6117fd4a2989e..08fcc6c2846a7 100644 --- a/localstack/services/stepfunctions/backend/execution.py +++ b/localstack/services/stepfunctions/backend/execution.py @@ -221,6 +221,8 @@ def start(self) -> None: Name=self.name, RoleArn=self.role_arn, StartTime=self.start_date.time().isoformat(), + AccountId=self.account_id, + RegionName=self.region_name, ), StateMachine=ContextObjectStateMachine( Id=self.state_machine.arn, diff --git a/localstack/services/stepfunctions/backend/utils.py b/localstack/services/stepfunctions/backend/utils.py new file mode 100644 index 0000000000000..b54c2bd4dc501 --- /dev/null +++ b/localstack/services/stepfunctions/backend/utils.py @@ -0,0 +1,14 @@ +from botocore.config import Config + +from localstack.aws.connect import connect_to +from localstack.services.stepfunctions.asl.eval.environment import Environment + + +def get_boto_client(env: Environment, service: str): + execution = env.context_object_manager.context_object["Execution"] + return connect_to.get_client( + aws_access_key_id=execution["AccountId"], + region_name=execution["RegionName"], + service_name=service, + config=Config(parameter_validation=False), + ) From 9cf16d01901a531ca90a2bec3cce1b9c18f3efdb Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Wed, 13 Sep 2023 17:49:53 +0530 Subject: [PATCH 07/14] Run the legacy SF tests in a hardcoded region --- .../stepfunctions/stepfunctions_starter.py | 20 +- .../legacy/test_stepfunctions_legacy.py | 239 ++++++++++-------- 2 files changed, 136 insertions(+), 123 deletions(-) diff --git a/localstack/services/stepfunctions/stepfunctions_starter.py b/localstack/services/stepfunctions/stepfunctions_starter.py index bf3ff382d8b38..b0f10534631f9 100644 --- a/localstack/services/stepfunctions/stepfunctions_starter.py +++ b/localstack/services/stepfunctions/stepfunctions_starter.py @@ -51,18 +51,14 @@ def generate_env_vars(self) -> Dict[str, Any]: def generate_shell_command(self) -> str: cmd = ( - "java " - "-javaagent:aspectjweaver-1.9.7.jar " - "-Dorg.aspectj.weaver.loadtime.configuration=META-INF/aop.xml " - "-Dcom.amazonaws.sdk.disableCertChecking " - "-Xmx%s " - "-jar StepFunctionsLocal.jar " - "--aws-account %s " - "--region %s" - ) % ( - MAX_HEAP_SIZE, - self.account_id, - self.region_name, + f"java " + f"-javaagent:aspectjweaver-1.9.7.jar " + f"-Dorg.aspectj.weaver.loadtime.configuration=META-INF/aop.xml " + f"-Dcom.amazonaws.sdk.disableCertChecking " + f"-Xmx{MAX_HEAP_SIZE} " + f"-jar StepFunctionsLocal.jar " + f"--aws-account {self.account_id} " + f"--aws-region {self.region_name} " ) if config.STEPFUNCTIONS_LAMBDA_ENDPOINT.lower() != "default": diff --git a/tests/aws/services/stepfunctions/legacy/test_stepfunctions_legacy.py b/tests/aws/services/stepfunctions/legacy/test_stepfunctions_legacy.py index dd8badd8a0a09..f7939b645b39b 100644 --- a/tests/aws/services/stepfunctions/legacy/test_stepfunctions_legacy.py +++ b/tests/aws/services/stepfunctions/legacy/test_stepfunctions_legacy.py @@ -4,7 +4,7 @@ import pytest -from localstack.constants import TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME +from localstack.constants import TEST_AWS_REGION_NAME from localstack.services.events.provider import TEST_EVENTS_CACHE from localstack.services.stepfunctions.stepfunctions_utils import await_sfn_execution_result from localstack.testing.pytest import markers @@ -173,10 +173,23 @@ LOG = logging.getLogger(__name__) +# The legacy StepFunctions provider does not properly support multi-accounts +# Although StepFunctions Local has an `--account-id` argument, +# it does not obey the override especially during Lambda invocations. +# As such, the tests in this module only run for the following account. +SF_TEST_AWS_ACCOUNT_ID = "000000000000" + + +@pytest.fixture(scope="module") +def custom_client(aws_client_factory): + return aws_client_factory( + region_name=TEST_AWS_REGION_NAME, aws_access_key_id=SF_TEST_AWS_ACCOUNT_ID + ) + @pytest.fixture(scope="module") -def setup_and_tear_down(aws_client): - lambda_client = aws_client.lambda_ +def setup_and_tear_down(custom_client): + lambda_client = custom_client.lambda_ zip_file = testutil.create_lambda_archive(load_file(TEST_LAMBDA_ENV), get_content=True) zip_file2 = testutil.create_lambda_archive(load_file(TEST_LAMBDA_PYTHON_ECHO), get_content=True) @@ -184,35 +197,35 @@ def setup_and_tear_down(aws_client): func_name=TEST_LAMBDA_NAME_1, zip_file=zip_file, envvars={"Hello": TEST_RESULT_VALUE}, - client=aws_client.lambda_, - s3_client=aws_client.s3, + client=custom_client.lambda_, + s3_client=custom_client.s3, ) testutil.create_lambda_function( func_name=TEST_LAMBDA_NAME_2, zip_file=zip_file, envvars={"Hello": TEST_RESULT_VALUE_2}, - client=aws_client.lambda_, - s3_client=aws_client.s3, + client=custom_client.lambda_, + s3_client=custom_client.s3, ) testutil.create_lambda_function( func_name=TEST_LAMBDA_NAME_3, zip_file=zip_file, envvars={"Hello": "Replace Value"}, - client=aws_client.lambda_, - s3_client=aws_client.s3, + client=custom_client.lambda_, + s3_client=custom_client.s3, ) testutil.create_lambda_function( func_name=TEST_LAMBDA_NAME_4, zip_file=zip_file, envvars={"Hello": TEST_RESULT_VALUE_4}, - client=aws_client.lambda_, - s3_client=aws_client.s3, + client=custom_client.lambda_, + s3_client=custom_client.s3, ) testutil.create_lambda_function( func_name=TEST_LAMBDA_NAME_5, zip_file=zip_file2, - client=aws_client.lambda_, - s3_client=aws_client.s3, + client=custom_client.lambda_, + s3_client=custom_client.s3, ) active_waiter = lambda_client.get_waiter("function_active_v2") @@ -224,17 +237,17 @@ def setup_and_tear_down(aws_client): yield - aws_client.lambda_.delete_function(FunctionName=TEST_LAMBDA_NAME_1) - aws_client.lambda_.delete_function(FunctionName=TEST_LAMBDA_NAME_2) - aws_client.lambda_.delete_function(FunctionName=TEST_LAMBDA_NAME_3) - aws_client.lambda_.delete_function(FunctionName=TEST_LAMBDA_NAME_4) - aws_client.lambda_.delete_function(FunctionName=TEST_LAMBDA_NAME_5) + custom_client.lambda_.delete_function(FunctionName=TEST_LAMBDA_NAME_1) + custom_client.lambda_.delete_function(FunctionName=TEST_LAMBDA_NAME_2) + custom_client.lambda_.delete_function(FunctionName=TEST_LAMBDA_NAME_3) + custom_client.lambda_.delete_function(FunctionName=TEST_LAMBDA_NAME_4) + custom_client.lambda_.delete_function(FunctionName=TEST_LAMBDA_NAME_5) @pytest.fixture -def sfn_execution_role(aws_client): +def sfn_execution_role(custom_client): role_name = f"role-{short_uid()}" - result = aws_client.iam.create_role( + result = custom_client.iam.create_role( RoleName=role_name, AssumeRolePolicyDocument='{"Version": "2012-10-17", "Statement": {"Action": "sts:AssumeRole", "Effect": "Allow", "Principal": {"Service": "states.amazonaws.com"}}}', ) @@ -286,28 +299,28 @@ def get_machine_arn(sm_name, sfn_client): @pytest.mark.usefixtures("setup_and_tear_down") class TestStateMachine: @markers.aws.unknown - def test_create_choice_state_machine(self, aws_client): - state_machines_before = aws_client.stepfunctions.list_state_machines()["stateMachines"] - role_arn = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) + def test_create_choice_state_machine(self, custom_client): + state_machines_before = custom_client.stepfunctions.list_state_machines()["stateMachines"] + role_arn = arns.role_arn("sfn_role", SF_TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) definition = clone(STATE_MACHINE_CHOICE) lambda_arn_4 = arns.lambda_function_arn( - TEST_LAMBDA_NAME_4, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + TEST_LAMBDA_NAME_4, SF_TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME ) definition["States"]["Add"]["Resource"] = lambda_arn_4 definition = json.dumps(definition) sm_name = f"choice-{short_uid()}" - aws_client.stepfunctions.create_state_machine( + custom_client.stepfunctions.create_state_machine( name=sm_name, definition=definition, roleArn=role_arn ) # assert that the SM has been created - assert_machine_created(state_machines_before, aws_client.stepfunctions) + assert_machine_created(state_machines_before, custom_client.stepfunctions) # run state machine - sm_arn = get_machine_arn(sm_name, aws_client.stepfunctions) + sm_arn = get_machine_arn(sm_name, custom_client.stepfunctions) input = {"x": "1", "y": "2"} - result = aws_client.stepfunctions.start_execution( + result = custom_client.stepfunctions.start_execution( stateMachineArn=sm_arn, input=json.dumps(input) ) assert result.get("executionArn") @@ -316,154 +329,154 @@ def test_create_choice_state_machine(self, aws_client): test_output = {**input, "added": {"Hello": TEST_RESULT_VALUE_4}} def check_result(): - result = _get_execution_results(sm_arn, aws_client.stepfunctions) + result = _get_execution_results(sm_arn, custom_client.stepfunctions) assert test_output == result # assert that the result is correct retry(check_result, sleep=2, retries=10) # clean up - cleanup(sm_arn, state_machines_before, sfn_client=aws_client.stepfunctions) + cleanup(sm_arn, state_machines_before, sfn_client=custom_client.stepfunctions) @markers.aws.unknown - def test_create_run_map_state_machine(self, aws_client): + def test_create_run_map_state_machine(self, custom_client): names = ["Bob", "Meg", "Joe"] test_input = [{"map": name} for name in names] test_output = [{"Hello": name} for name in names] - state_machines_before = aws_client.stepfunctions.list_state_machines()["stateMachines"] + state_machines_before = custom_client.stepfunctions.list_state_machines()["stateMachines"] - role_arn = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) + role_arn = arns.role_arn("sfn_role", SF_TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) definition = clone(STATE_MACHINE_MAP) lambda_arn_3 = arns.lambda_function_arn( - TEST_LAMBDA_NAME_3, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + TEST_LAMBDA_NAME_3, SF_TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME ) definition["States"]["ExampleMapState"]["Iterator"]["States"]["CallLambda"][ "Resource" ] = lambda_arn_3 definition = json.dumps(definition) sm_name = f"map-{short_uid()}" - aws_client.stepfunctions.create_state_machine( + custom_client.stepfunctions.create_state_machine( name=sm_name, definition=definition, roleArn=role_arn ) # assert that the SM has been created - assert_machine_created(state_machines_before, aws_client.stepfunctions) + assert_machine_created(state_machines_before, custom_client.stepfunctions) # run state machine - sm_arn = get_machine_arn(sm_name, aws_client.stepfunctions) - result = aws_client.stepfunctions.start_execution( + sm_arn = get_machine_arn(sm_name, custom_client.stepfunctions) + result = custom_client.stepfunctions.start_execution( stateMachineArn=sm_arn, input=json.dumps(test_input) ) assert result.get("executionArn") def check_invocations(): # assert that the result is correct - result = _get_execution_results(sm_arn, aws_client.stepfunctions) + result = _get_execution_results(sm_arn, custom_client.stepfunctions) assert test_output == result # assert that the lambda has been invoked by the SM execution retry(check_invocations, sleep=1, retries=10) # clean up - cleanup(sm_arn, state_machines_before, aws_client.stepfunctions) + cleanup(sm_arn, state_machines_before, custom_client.stepfunctions) @markers.aws.unknown - def test_create_run_state_machine(self, aws_client): - state_machines_before = aws_client.stepfunctions.list_state_machines()["stateMachines"] + def test_create_run_state_machine(self, custom_client): + state_machines_before = custom_client.stepfunctions.list_state_machines()["stateMachines"] # create state machine - role_arn = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) + role_arn = arns.role_arn("sfn_role", SF_TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) definition = clone(STATE_MACHINE_BASIC) lambda_arn_1 = arns.lambda_function_arn( - TEST_LAMBDA_NAME_1, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + TEST_LAMBDA_NAME_1, SF_TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME ) lambda_arn_2 = arns.lambda_function_arn( - TEST_LAMBDA_NAME_2, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + TEST_LAMBDA_NAME_2, SF_TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME ) definition["States"]["step1"]["Resource"] = lambda_arn_1 definition["States"]["step2"]["Resource"] = lambda_arn_2 definition = json.dumps(definition) sm_name = f"basic-{short_uid()}" - aws_client.stepfunctions.create_state_machine( + custom_client.stepfunctions.create_state_machine( name=sm_name, definition=definition, roleArn=role_arn ) # assert that the SM has been created - assert_machine_created(state_machines_before, aws_client.stepfunctions) + assert_machine_created(state_machines_before, custom_client.stepfunctions) # run state machine - sm_arn = get_machine_arn(sm_name, aws_client.stepfunctions) - result = aws_client.stepfunctions.start_execution(stateMachineArn=sm_arn) + sm_arn = get_machine_arn(sm_name, custom_client.stepfunctions) + result = custom_client.stepfunctions.start_execution(stateMachineArn=sm_arn) assert result.get("executionArn") def check_invocations(): # assert that the result is correct - result = _get_execution_results(sm_arn, aws_client.stepfunctions) + result = _get_execution_results(sm_arn, custom_client.stepfunctions) assert {"Hello": TEST_RESULT_VALUE_2} == result["result_value"] # assert that the lambda has been invoked by the SM execution retry(check_invocations, sleep=0.7, retries=25) # clean up - cleanup(sm_arn, state_machines_before, aws_client.stepfunctions) + cleanup(sm_arn, state_machines_before, custom_client.stepfunctions) @markers.aws.unknown - def test_try_catch_state_machine(self, aws_client): + def test_try_catch_state_machine(self, custom_client): if os.environ.get("AWS_DEFAULT_REGION") != "us-east-1": pytest.skip("skipping non us-east-1 temporarily") - state_machines_before = aws_client.stepfunctions.list_state_machines()["stateMachines"] + state_machines_before = custom_client.stepfunctions.list_state_machines()["stateMachines"] # create state machine - role_arn = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) + role_arn = arns.role_arn("sfn_role", SF_TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) definition = clone(STATE_MACHINE_CATCH) lambda_arn_1 = arns.lambda_function_arn( - TEST_LAMBDA_NAME_1, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + TEST_LAMBDA_NAME_1, SF_TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME ) lambda_arn_2 = arns.lambda_function_arn( - TEST_LAMBDA_NAME_2, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + TEST_LAMBDA_NAME_2, SF_TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME ) definition["States"]["Start"]["Parameters"]["FunctionName"] = lambda_arn_1 definition["States"]["ErrorHandler"]["Resource"] = lambda_arn_2 definition["States"]["Final"]["Resource"] = lambda_arn_2 definition = json.dumps(definition) sm_name = f"catch-{short_uid()}" - aws_client.stepfunctions.create_state_machine( + custom_client.stepfunctions.create_state_machine( name=sm_name, definition=definition, roleArn=role_arn ) # run state machine - sm_arn = get_machine_arn(sm_name, aws_client.stepfunctions) - result = aws_client.stepfunctions.start_execution(stateMachineArn=sm_arn) + sm_arn = get_machine_arn(sm_name, custom_client.stepfunctions) + result = custom_client.stepfunctions.start_execution(stateMachineArn=sm_arn) assert result.get("executionArn") def check_invocations(): # assert that the result is correct - result = _get_execution_results(sm_arn, aws_client.stepfunctions) + result = _get_execution_results(sm_arn, custom_client.stepfunctions) assert {"Hello": TEST_RESULT_VALUE_2} == result.get("handled") # assert that the lambda has been invoked by the SM execution retry(check_invocations, sleep=1, retries=10) # clean up - cleanup(sm_arn, state_machines_before, aws_client.stepfunctions) + cleanup(sm_arn, state_machines_before, custom_client.stepfunctions) # TODO: validate against AWS @markers.aws.unknown - def test_intrinsic_functions(self, aws_client): + def test_intrinsic_functions(self, custom_client): if os.environ.get("AWS_DEFAULT_REGION") != "us-east-1": pytest.skip("skipping non us-east-1 temporarily") - state_machines_before = aws_client.stepfunctions.list_state_machines()["stateMachines"] + state_machines_before = custom_client.stepfunctions.list_state_machines()["stateMachines"] # create state machine - role_arn = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) + role_arn = arns.role_arn("sfn_role", SF_TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) definition = clone(STATE_MACHINE_INTRINSIC_FUNCS) lambda_arn_1 = arns.lambda_function_arn( - TEST_LAMBDA_NAME_5, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + TEST_LAMBDA_NAME_5, SF_TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME ) lambda_arn_2 = arns.lambda_function_arn( - TEST_LAMBDA_NAME_5, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + TEST_LAMBDA_NAME_5, SF_TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME ) if isinstance(definition["States"]["state1"].get("Parameters"), dict): definition["States"]["state1"]["Parameters"]["lambda_params"][ @@ -472,33 +485,33 @@ def test_intrinsic_functions(self, aws_client): definition["States"]["state3"]["Resource"] = lambda_arn_2 definition = json.dumps(definition) sm_name = f"intrinsic-{short_uid()}" - aws_client.stepfunctions.create_state_machine( + custom_client.stepfunctions.create_state_machine( name=sm_name, definition=definition, roleArn=role_arn ) # run state machine - sm_arn = get_machine_arn(sm_name, aws_client.stepfunctions) + sm_arn = get_machine_arn(sm_name, custom_client.stepfunctions) input = {} - result = aws_client.stepfunctions.start_execution( + result = custom_client.stepfunctions.start_execution( stateMachineArn=sm_arn, input=json.dumps(input) ) assert result.get("executionArn") def check_invocations(): # assert that the result is correct - result = _get_execution_results(sm_arn, aws_client.stepfunctions) + result = _get_execution_results(sm_arn, custom_client.stepfunctions) assert {"payload": {"values": [1, "v2"]}} == result.get("result_value") # assert that the lambda has been invoked by the SM execution retry(check_invocations, sleep=1, retries=10) # clean up - cleanup(sm_arn, state_machines_before, aws_client.stepfunctions) + cleanup(sm_arn, state_machines_before, custom_client.stepfunctions) @markers.aws.unknown - def test_events_state_machine(self, aws_client): - events = aws_client.events - state_machines_before = aws_client.stepfunctions.list_state_machines()["stateMachines"] + def test_events_state_machine(self, custom_client): + events = custom_client.events + state_machines_before = custom_client.stepfunctions.list_state_machines()["stateMachines"] # create event bus bus_name = f"bus-{short_uid()}" @@ -509,15 +522,15 @@ def test_events_state_machine(self, aws_client): definition["States"]["step1"]["Parameters"]["Entries"][0]["EventBusName"] = bus_name definition = json.dumps(definition) sm_name = f"events-{short_uid()}" - role_arn = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) - aws_client.stepfunctions.create_state_machine( + role_arn = arns.role_arn("sfn_role", SF_TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) + custom_client.stepfunctions.create_state_machine( name=sm_name, definition=definition, roleArn=role_arn ) # run state machine events_before = len(TEST_EVENTS_CACHE) - sm_arn = get_machine_arn(sm_name, aws_client.stepfunctions) - result = aws_client.stepfunctions.start_execution(stateMachineArn=sm_arn) + sm_arn = get_machine_arn(sm_name, custom_client.stepfunctions) + result = custom_client.stepfunctions.start_execution(stateMachineArn=sm_arn) assert result.get("executionArn") def check_invocations(): @@ -533,11 +546,11 @@ def check_invocations(): retry(check_invocations, sleep=1, retries=10) # clean up - cleanup(sm_arn, state_machines_before, aws_client.stepfunctions) + cleanup(sm_arn, state_machines_before, custom_client.stepfunctions) events.delete_event_bus(Name=bus_name) @markers.aws.unknown - def test_create_state_machines_in_parallel(self, cleanups, aws_client): + def test_create_state_machines_in_parallel(self, cleanups, custom_client): """ Perform a test that creates a series of state machines in parallel. Without concurrency control, using StepFunctions-Local, the following error is pretty consistently reproducible: @@ -546,10 +559,10 @@ def test_create_state_machines_in_parallel(self, cleanups, aws_client): CreateStateMachine operation: Invalid State Machine Definition: ''DUPLICATE_STATE_NAME: Duplicate State name: MissingValue at /States/MissingValue', 'DUPLICATE_STATE_NAME: Duplicate State name: Add at /States/Add'' """ - role_arn = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) + role_arn = arns.role_arn("sfn_role", SF_TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) definition = clone(STATE_MACHINE_CHOICE) lambda_arn_4 = arns.lambda_function_arn( - TEST_LAMBDA_NAME_4, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME + TEST_LAMBDA_NAME_4, SF_TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME ) definition["States"]["Add"]["Resource"] = lambda_arn_4 definition = json.dumps(definition) @@ -557,20 +570,22 @@ def test_create_state_machines_in_parallel(self, cleanups, aws_client): def _create_sm(*_): sm_name = f"sm-{short_uid()}" - result = aws_client.stepfunctions.create_state_machine( + result = custom_client.stepfunctions.create_state_machine( name=sm_name, definition=definition, roleArn=role_arn ) assert result["ResponseMetadata"]["HTTPStatusCode"] == 200 cleanups.append( - lambda: aws_client.stepfunctions.delete_state_machine( + lambda: custom_client.stepfunctions.delete_state_machine( stateMachineArn=result["stateMachineArn"] ) ) results.append(result) - aws_client.stepfunctions.describe_state_machine( + custom_client.stepfunctions.describe_state_machine( stateMachineArn=result["stateMachineArn"] ) - aws_client.stepfunctions.list_tags_for_resource(resourceArn=result["stateMachineArn"]) + custom_client.stepfunctions.list_tags_for_resource( + resourceArn=result["stateMachineArn"] + ) num_machines = 30 parallelize(_create_sm, list(range(num_machines)), size=2) @@ -630,10 +645,12 @@ def _create_sm(*_): @pytest.mark.parametrize("statemachine_definition", (TEST_STATE_MACHINE_3,)) # TODO: add sync2 test @markers.aws.unknown def test_multiregion_nested(aws_client_factory, region_name, statemachine_definition): - client1 = aws_client_factory(region_name=region_name).stepfunctions + client1 = aws_client_factory( + region_name=region_name, aws_access_key_id=SF_TEST_AWS_ACCOUNT_ID + ).stepfunctions # create state machine child_machine_name = f"sf-child-{short_uid()}" - role = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) + role = arns.role_arn("sfn_role", SF_TEST_AWS_ACCOUNT_ID, region_name) child_machine_result = client1.create_state_machine( name=child_machine_name, definition=json.dumps(TEST_STATE_MACHINE), roleArn=role ) @@ -641,7 +658,7 @@ def test_multiregion_nested(aws_client_factory, region_name, statemachine_defini # create parent state machine name = f"sf-parent-{short_uid()}" - role = arns.role_arn("sfn_role", TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME) + role = arns.role_arn("sfn_role", SF_TEST_AWS_ACCOUNT_ID, region_name) result = client1.create_state_machine( name=name, definition=json.dumps(statemachine_definition).replace( @@ -681,10 +698,10 @@ def assert_success(): @markers.aws.validated -def test_default_logging_configuration(create_state_machine, aws_client): +def test_default_logging_configuration(create_state_machine, custom_client): role_name = f"role_name-{short_uid()}" try: - role_arn = aws_client.iam.create_role( + role_arn = custom_client.iam.create_role( RoleName=role_name, AssumeRolePolicyDocument=json.dumps(STS_ROLE_POLICY_DOC), )["Role"]["Arn"] @@ -696,18 +713,18 @@ def test_default_logging_configuration(create_state_machine, aws_client): result = create_state_machine(name=sm_name, definition=definition, roleArn=role_arn) assert result["ResponseMetadata"]["HTTPStatusCode"] == 200 - result = aws_client.stepfunctions.describe_state_machine( + result = custom_client.stepfunctions.describe_state_machine( stateMachineArn=result["stateMachineArn"] ) assert result["ResponseMetadata"]["HTTPStatusCode"] == 200 assert result["loggingConfiguration"] == {"level": "OFF", "includeExecutionData": False} finally: - aws_client.iam.delete_role(RoleName=role_name) + custom_client.iam.delete_role(RoleName=role_name) @pytest.mark.skip("Does not work against Pro in new pipeline.") @markers.aws.unknown -def test_aws_sdk_task(sfn_execution_role, aws_client): +def test_aws_sdk_task(sfn_execution_role, custom_client): statemachine_definition = { "StartAt": "CreateTopicTask", "States": { @@ -725,15 +742,15 @@ def test_aws_sdk_task(sfn_execution_role, aws_client): policy_name = f"policy-{short_uid()}" topic_name = f"topic-{short_uid()}" - policy = aws_client.iam.create_policy( + policy = custom_client.iam.create_policy( PolicyDocument='{"Version": "2012-10-17", "Statement": {"Action": "sns:createTopic", "Effect": "Allow", "Resource": "*"}}', PolicyName=policy_name, ) - aws_client.iam.attach_role_policy( + custom_client.iam.attach_role_policy( RoleName=sfn_execution_role["RoleName"], PolicyArn=policy["Policy"]["Arn"] ) - result = aws_client.stepfunctions.create_state_machine( + result = custom_client.stepfunctions.create_state_machine( name=name, definition=json.dumps(statemachine_definition), roleArn=sfn_execution_role["Arn"], @@ -741,13 +758,13 @@ def test_aws_sdk_task(sfn_execution_role, aws_client): machine_arn = result["stateMachineArn"] try: - result = aws_client.stepfunctions.list_state_machines()["stateMachines"] + result = custom_client.stepfunctions.list_state_machines()["stateMachines"] assert len(result) > 0 assert len([sm for sm in result if sm["name"] == name]) == 1 def assert_execution_success(executionArn: str): def _assert_execution_success(): - status = aws_client.stepfunctions.describe_execution(executionArn=executionArn)[ + status = custom_client.stepfunctions.describe_execution(executionArn=executionArn)[ "status" ] if status == "FAILED": @@ -761,35 +778,35 @@ def _retry_execution(): # start state machine execution # AWS initially straight up fails until the permissions seem to take effect # so we wait until the statemachine is at least running - result = aws_client.stepfunctions.start_execution( + result = custom_client.stepfunctions.start_execution( stateMachineArn=machine_arn, input='{"Name": "' f"{topic_name}" '"}' ) assert wait_until(assert_execution_success(result["executionArn"])) - describe_result = aws_client.stepfunctions.describe_execution( + describe_result = custom_client.stepfunctions.describe_execution( executionArn=result["executionArn"] ) output = describe_result["output"] assert topic_name in output - result = aws_client.stepfunctions.describe_state_machine_for_execution( + result = custom_client.stepfunctions.describe_state_machine_for_execution( executionArn=result["executionArn"] ) assert result["stateMachineArn"] == machine_arn topic_arn = json.loads(describe_result["output"])["TopicArn"] - topics = aws_client.sns.list_topics() + topics = custom_client.sns.list_topics() assert topic_arn in [t["TopicArn"] for t in topics["Topics"]] - aws_client.sns.delete_topic(TopicArn=topic_arn) + custom_client.sns.delete_topic(TopicArn=topic_arn) return True assert wait_until(_retry_execution, max_retries=3, strategy="linear", wait=3.0) finally: - aws_client.iam.delete_policy(PolicyArn=policy["Policy"]["Arn"]) - aws_client.stepfunctions.delete_state_machine(stateMachineArn=machine_arn) + custom_client.iam.delete_policy(PolicyArn=policy["Policy"]["Arn"]) + custom_client.stepfunctions.delete_state_machine(stateMachineArn=machine_arn) @pytest.mark.skip("Does not work against Pro in new pipeline.") @markers.aws.unknown -def test_aws_sdk_task_delete_s3_object(s3_bucket, sfn_execution_role, aws_client): +def test_aws_sdk_task_delete_s3_object(s3_bucket, sfn_execution_role, custom_client): s3_key = "test-key" statemachine_definition = { "StartAt": "CreateTopicTask", @@ -804,21 +821,21 @@ def test_aws_sdk_task_delete_s3_object(s3_bucket, sfn_execution_role, aws_client } # create state machine - aws_client.s3.put_object(Bucket=s3_bucket, Key=s3_key, Body=b"") + custom_client.s3.put_object(Bucket=s3_bucket, Key=s3_key, Body=b"") name = f"statemachine-{short_uid()}" - result = aws_client.stepfunctions.create_state_machine( + result = custom_client.stepfunctions.create_state_machine( name=name, definition=json.dumps(statemachine_definition), roleArn=sfn_execution_role["Arn"], ) machine_arn = result["stateMachineArn"] - result = aws_client.stepfunctions.start_execution(stateMachineArn=machine_arn, input="{}") + result = custom_client.stepfunctions.start_execution(stateMachineArn=machine_arn, input="{}") execution_arn = result["executionArn"] await_sfn_execution_result(execution_arn) with pytest.raises(Exception) as exc: - aws_client.s3.head_object(Bucket=s3_bucket, Key=s3_key) + custom_client.s3.head_object(Bucket=s3_bucket, Key=s3_key) assert "Not Found" in str(exc) From b2917c09a70d310f3d8b94dd3d3dece143f0d62b Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Wed, 13 Sep 2023 18:24:28 +0530 Subject: [PATCH 08/14] Use an alternative way to pass the context --- .../asl/eval/contextobject/contex_object.py | 2 -- .../services/stepfunctions/asl/eval/environment.py | 14 ++++++++++++-- .../services/stepfunctions/backend/execution.py | 4 ++-- .../stepfunctions/backend/execution_worker.py | 10 +++++++++- localstack/services/stepfunctions/backend/utils.py | 5 ++--- 5 files changed, 25 insertions(+), 10 deletions(-) diff --git a/localstack/services/stepfunctions/asl/eval/contextobject/contex_object.py b/localstack/services/stepfunctions/asl/eval/contextobject/contex_object.py index ec44f93b73573..33fbadc30cf11 100644 --- a/localstack/services/stepfunctions/asl/eval/contextobject/contex_object.py +++ b/localstack/services/stepfunctions/asl/eval/contextobject/contex_object.py @@ -9,8 +9,6 @@ class Execution(TypedDict): Name: str RoleArn: str StartTime: str # Format: ISO 8601. - AccountId: str - RegionName: str class State(TypedDict): diff --git a/localstack/services/stepfunctions/asl/eval/environment.py b/localstack/services/stepfunctions/asl/eval/environment.py index 9bcc001c0c8d5..fe4915223cddf 100644 --- a/localstack/services/stepfunctions/asl/eval/environment.py +++ b/localstack/services/stepfunctions/asl/eval/environment.py @@ -26,8 +26,14 @@ class Environment: - def __init__(self, context_object_init: ContextObjectInitData): + def __init__( + self, account_id: str, region_name: str, context_object_init: ContextObjectInitData + ): super(Environment, self).__init__() + + self.account_id = account_id + self.region_name = region_name + self._state_mutex = threading.RLock() self._program_state: Optional[ProgramState] = None self.program_state_event = threading.Event() @@ -54,7 +60,11 @@ def as_frame_of(cls, env: Environment): Execution=env.context_object_manager.context_object["Execution"], StateMachine=env.context_object_manager.context_object["StateMachine"], ) - frame = cls(context_object_init=context_object_init) + frame = cls( + account_id=env.account_id, + region_name=env.region_name, + context_object_init=context_object_init, + ) frame._is_frame = True frame.event_history = env.event_history frame.callback_pool_manager = env.callback_pool_manager diff --git a/localstack/services/stepfunctions/backend/execution.py b/localstack/services/stepfunctions/backend/execution.py index 08fcc6c2846a7..f667a94eaece2 100644 --- a/localstack/services/stepfunctions/backend/execution.py +++ b/localstack/services/stepfunctions/backend/execution.py @@ -210,6 +210,8 @@ def start(self) -> None: raise InvalidName() # TODO. self.exec_worker = ExecutionWorker( + account_id=self.account_id, + region_name=self.region_name, role_arn=self.role_arn, definition=self.state_machine.definition, input_data=self.input_data, @@ -221,8 +223,6 @@ def start(self) -> None: Name=self.name, RoleArn=self.role_arn, StartTime=self.start_date.time().isoformat(), - AccountId=self.account_id, - RegionName=self.region_name, ), StateMachine=ContextObjectStateMachine( Id=self.state_machine.arn, diff --git a/localstack/services/stepfunctions/backend/execution_worker.py b/localstack/services/stepfunctions/backend/execution_worker.py index 7e326a0dbab5b..22401084102a3 100644 --- a/localstack/services/stepfunctions/backend/execution_worker.py +++ b/localstack/services/stepfunctions/backend/execution_worker.py @@ -31,12 +31,16 @@ class ExecutionWorker: def __init__( self, + account_id: str, + region_name: str, role_arn: Arn, definition: Definition, input_data: Optional[dict], context_object_init: ContextObjectInitData, exec_comm: ExecutionWorkerComm, ): + self.account_id = account_id + self.region_name = region_name self.role_arn = role_arn self.definition = definition self.input_data = input_data @@ -46,7 +50,11 @@ def __init__( def _execution_logic(self): program: Program = AmazonStateLanguageParser.parse(self.definition) - self.env = Environment(context_object_init=self._context_object_init) + self.env = Environment( + account_id=self.account_id, + region_name=self.region_name, + context_object_init=self._context_object_init, + ) self.env.inp = copy.deepcopy( self.input_data ) # The program will mutate the input_data, which is otherwise constant in regard to the execution value. diff --git a/localstack/services/stepfunctions/backend/utils.py b/localstack/services/stepfunctions/backend/utils.py index b54c2bd4dc501..d89817f8c71a6 100644 --- a/localstack/services/stepfunctions/backend/utils.py +++ b/localstack/services/stepfunctions/backend/utils.py @@ -5,10 +5,9 @@ def get_boto_client(env: Environment, service: str): - execution = env.context_object_manager.context_object["Execution"] return connect_to.get_client( - aws_access_key_id=execution["AccountId"], - region_name=execution["RegionName"], + aws_access_key_id=env.account_id, + region_name=env.region_name, service_name=service, config=Config(parameter_validation=False), ) From e43191ccc677e92988979fac581a437bc2769f9a Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Thu, 14 Sep 2023 13:08:38 +0530 Subject: [PATCH 09/14] Use request context account ID and region during program construction --- .../state_execution/state_task/service/resource.py | 7 ++++--- .../services/stepfunctions/asl/parse/asl_parser.py | 4 ++-- .../stepfunctions/asl/parse/preprocessor.py | 6 +++++- .../stepfunctions/backend/execution_worker.py | 4 +++- localstack/services/stepfunctions/provider_v2.py | 14 ++++++++++---- 5 files changed, 24 insertions(+), 11 deletions(-) diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/resource.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/resource.py index f790e8cd03e7c..c8cd46b96ce1e 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/resource.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/resource.py @@ -5,7 +5,6 @@ from typing import Final, Optional from localstack.services.stepfunctions.asl.component.component import Component -from localstack.utils.aws import aws_stack class ResourceCondition(str): @@ -85,10 +84,12 @@ def __init__(self, resource_arn: ResourceARN): self.account = resource_arn.account @staticmethod - def from_resource_arn(arn: str) -> Resource: + def from_resource_arn(account_id: str, region_name: str, arn: str) -> Resource: resource_arn = ResourceARN.from_arn(arn) + if not resource_arn.account: + resource_arn.account = account_id if not resource_arn.region: - resource_arn.region = aws_stack.get_region() + resource_arn.region = region_name match resource_arn.service, resource_arn.task_type: case "lambda", "function": return LambdaResource(resource_arn=resource_arn) diff --git a/localstack/services/stepfunctions/asl/parse/asl_parser.py b/localstack/services/stepfunctions/asl/parse/asl_parser.py index 260c326523c59..0d21557ef5f99 100644 --- a/localstack/services/stepfunctions/asl/parse/asl_parser.py +++ b/localstack/services/stepfunctions/asl/parse/asl_parser.py @@ -11,13 +11,13 @@ class AmazonStateLanguageParser(abc.ABC): @staticmethod - def parse(src: str) -> Program: + def parse(account_id: str, region_name: str, src: str) -> Program: input_stream = InputStream(src) lexer = ASLLexer(input_stream) stream = CommonTokenStream(lexer) parser = ASLParser(stream) parser._errHandler = antlr4.BailErrorStrategy() tree = parser.program_decl() - preprocessor = Preprocessor() + preprocessor = Preprocessor(account_id, region_name) program = preprocessor.visit(tree) return program diff --git a/localstack/services/stepfunctions/asl/parse/preprocessor.py b/localstack/services/stepfunctions/asl/parse/preprocessor.py index 909c9a144ffc5..53db7856a94eb 100644 --- a/localstack/services/stepfunctions/asl/parse/preprocessor.py +++ b/localstack/services/stepfunctions/asl/parse/preprocessor.py @@ -191,6 +191,10 @@ class Preprocessor(ASLParserVisitor): + def __init__(self, account_id: str, region_name: str): + self.account_id = account_id + self.region_name = region_name + @staticmethod def _inner_string_of(parse_tree: ParseTree) -> Optional[str]: if Antlr4Utils.is_terminal(parse_tree, ASLLexer.NULL): @@ -231,7 +235,7 @@ def visitState_type(self, ctx: ASLParser.State_typeContext) -> StateType: def visitResource_decl(self, ctx: ASLParser.Resource_declContext) -> Resource: inner_str = self._inner_string_of(parse_tree=ctx.keyword_or_string()) - return Resource.from_resource_arn(inner_str) + return Resource.from_resource_arn(self.account_id, self.region_name, inner_str) def visitEnd_decl(self, ctx: ASLParser.End_declContext) -> End: bool_child: ParseTree = ctx.children[-1] diff --git a/localstack/services/stepfunctions/backend/execution_worker.py b/localstack/services/stepfunctions/backend/execution_worker.py index 22401084102a3..0569aa6132a7d 100644 --- a/localstack/services/stepfunctions/backend/execution_worker.py +++ b/localstack/services/stepfunctions/backend/execution_worker.py @@ -49,7 +49,9 @@ def __init__( self.exec_comm = exec_comm def _execution_logic(self): - program: Program = AmazonStateLanguageParser.parse(self.definition) + program: Program = AmazonStateLanguageParser.parse( + self.account_id, self.region_name, self.definition + ) self.env = Environment( account_id=self.account_id, region_name=self.region_name, diff --git a/localstack/services/stepfunctions/provider_v2.py b/localstack/services/stepfunctions/provider_v2.py index 7936812080ff0..e649db89930cd 100644 --- a/localstack/services/stepfunctions/provider_v2.py +++ b/localstack/services/stepfunctions/provider_v2.py @@ -148,11 +148,11 @@ def _revision_by_name( return None @staticmethod - def _validate_definition(definition: str): + def _validate_definition(account_id: str, region_name: str, definition: str): # Validate # TODO: pass through static analyser. try: - AmazonStateLanguageParser.parse(definition) + AmazonStateLanguageParser.parse(account_id, region_name, definition) except Exception as ex: # TODO: add message from static analyser, this just helps the user debug issues in the derivation. raise InvalidDefinition(f"Error '{str(ex)}' in definition '{definition}'.") @@ -182,7 +182,11 @@ def create_state_machine( ) state_machine_definition: str = request["definition"] - StepFunctionsProvider._validate_definition(definition=state_machine_definition) + StepFunctionsProvider._validate_definition( + account_id=context.account_id, + region_name=context.region, + definition=state_machine_definition, + ) name: Optional[Name] = request["name"] arn = aws_stack_state_machine_arn( @@ -498,7 +502,9 @@ def update_state_machine( ) if definition is not None: - self._validate_definition(definition=definition) + self._validate_definition( + account_id=context.account_id, region_name=context.region, definition=definition + ) revision_id = state_machine.create_revision(definition=definition, role_arn=role_arn) From aad9b80492fdc9f7b22f48af0184cb5febf6dd23 Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Thu, 14 Sep 2023 15:47:20 +0530 Subject: [PATCH 10/14] Maintain backward compatibility --- localstack/config.py | 2 ++ localstack/services/stepfunctions/stepfunctions_starter.py | 6 ++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/localstack/config.py b/localstack/config.py index 1b2c1725e9d4f..8f6b1f29af174 100644 --- a/localstack/config.py +++ b/localstack/config.py @@ -946,6 +946,8 @@ def legacy_fallback(envar_name: str, default: T) -> T: # DEV: sbx_user1051 (default when not provided) Alternative system user or empty string to skip dropping privileges. LAMBDA_INIT_USER = os.environ.get("LAMBDA_INIT_USER") +# Adding Stepfunctions default port +LOCAL_PORT_STEPFUNCTIONS = int(os.environ.get("LOCAL_PORT_STEPFUNCTIONS") or 8083) # Stepfunctions lambda endpoint override STEPFUNCTIONS_LAMBDA_ENDPOINT = os.environ.get("STEPFUNCTIONS_LAMBDA_ENDPOINT", "").strip() diff --git a/localstack/services/stepfunctions/stepfunctions_starter.py b/localstack/services/stepfunctions/stepfunctions_starter.py index b0f10534631f9..c9f064be76ab8 100644 --- a/localstack/services/stepfunctions/stepfunctions_starter.py +++ b/localstack/services/stepfunctions/stepfunctions_starter.py @@ -5,7 +5,7 @@ from localstack import config from localstack.services.stepfunctions.packages import stepfunctions_local_package from localstack.utils.aws import aws_stack -from localstack.utils.net import get_free_tcp_port +from localstack.utils.net import get_free_tcp_port, port_can_be_bound from localstack.utils.run import ShellCommandThread from localstack.utils.serving import Server from localstack.utils.threads import TMP_THREADS, FuncThread @@ -138,7 +138,9 @@ def shutdown_all(self): def _create_stepfunctions_server( self, account_id: str, region_name: str ) -> StepFunctionsServer: - port = get_free_tcp_port() + port = config.LOCAL_PORT_STEPFUNCTIONS + if not port_can_be_bound(port): + port = get_free_tcp_port() stepfunctions_local_package.install() server = StepFunctionsServer( From 74d02eb630f48f4ddf558a7976726081356ec463 Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Thu, 14 Sep 2023 17:40:14 +0530 Subject: [PATCH 11/14] Use proper client utility for eventbridge --- .../state_task/service/state_task_service_events.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_events.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_events.py index eb7f56e7b3f6d..aedfdc68131aa 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_events.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_events.py @@ -1,10 +1,7 @@ import json from typing import Final, Optional -from botocore.config import Config - from localstack.aws.api.stepfunctions import HistoryEventType, TaskFailedEventDetails -from localstack.aws.connect import connect_externally_to from localstack.services.stepfunctions.asl.component.common.error_name.custom_error_name import ( CustomErrorName, ) @@ -18,6 +15,7 @@ from localstack.services.stepfunctions.asl.eval.environment import Environment from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails from localstack.services.stepfunctions.asl.utils.encoding import to_json_str +from localstack.services.stepfunctions.backend.utils import get_boto_client from localstack.utils.strings import camel_to_snake_case @@ -76,7 +74,7 @@ def _normalised_request_parameters(env: Environment, parameters: dict): def _eval_service_task(self, env: Environment, parameters: dict) -> None: self._normalised_request_parameters(env=env, parameters=parameters) api_action = camel_to_snake_case(self.resource.api_action) - events_client = connect_externally_to(config=Config(parameter_validation=False)).events + events_client = get_boto_client(env, "events") response = getattr(events_client, api_action)(**parameters) response.pop("ResponseMetadata", None) From 54ba905756349381141b2e389145b9d35fdf5262 Mon Sep 17 00:00:00 2001 From: MEPalma Date: Mon, 2 Oct 2023 12:15:37 +0200 Subject: [PATCH 12/14] resources as evaluation components --- .../common/error_name/failure_event.py | 4 + .../state_task/lambda_eval_utils.py | 8 +- .../state_task/service/resource.py | 117 ++++++++++++++++-- .../state_task/service/state_task_service.py | 18 ++- .../service/state_task_service_api_gateway.py | 9 +- .../service/state_task_service_aws_sdk.py | 15 ++- .../service/state_task_service_dynamodb.py | 17 ++- .../service/state_task_service_events.py | 19 ++- .../service/state_task_service_lambda.py | 11 +- .../service/state_task_service_sfn.py | 32 +++-- .../service/state_task_service_sns.py | 15 ++- .../service/state_task_service_sqs.py | 15 ++- .../state_task/state_task_lambda.py | 7 +- .../asl/eval/aws_execution_details.py | 12 ++ .../stepfunctions/asl/eval/environment.py | 13 +- .../stepfunctions/asl/parse/asl_parser.py | 4 +- .../stepfunctions/asl/parse/preprocessor.py | 6 +- .../stepfunctions/asl/utils/boto_client.py | 13 ++ .../stepfunctions/backend/execution.py | 7 +- .../stepfunctions/backend/execution_worker.py | 41 +++--- .../services/stepfunctions/backend/utils.py | 13 -- .../services/stepfunctions/provider_v2.py | 18 +-- localstack/services/stores.py | 2 +- 23 files changed, 296 insertions(+), 120 deletions(-) create mode 100644 localstack/services/stepfunctions/asl/eval/aws_execution_details.py create mode 100644 localstack/services/stepfunctions/asl/utils/boto_client.py delete mode 100644 localstack/services/stepfunctions/backend/utils.py diff --git a/localstack/services/stepfunctions/asl/component/common/error_name/failure_event.py b/localstack/services/stepfunctions/asl/component/common/error_name/failure_event.py index 94fc2eb9bb6de..839eaba4d6318 100644 --- a/localstack/services/stepfunctions/asl/component/common/error_name/failure_event.py +++ b/localstack/services/stepfunctions/asl/component/common/error_name/failure_event.py @@ -3,6 +3,7 @@ from localstack.aws.api.stepfunctions import ExecutionFailedEventDetails, HistoryEventType from localstack.services.stepfunctions.asl.component.common.error_name.error_name import ErrorName from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails +from localstack.services.stepfunctions.asl.utils.encoding import to_json_str class FailureEvent: @@ -27,6 +28,9 @@ class FailureEventException(Exception): def __init__(self, failure_event: FailureEvent): self.failure_event = failure_event + def __str__(self) -> str: + return to_json_str(self.failure_event.event_details) + def get_execution_failed_event_details(self) -> Optional[ExecutionFailedEventDetails]: if self.failure_event.event_details is None: return None diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/lambda_eval_utils.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/lambda_eval_utils.py index 6439b1783cbd3..eb66a21392878 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/lambda_eval_utils.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/lambda_eval_utils.py @@ -4,8 +4,8 @@ from localstack.aws.api.lambda_ import InvocationResponse from localstack.services.stepfunctions.asl.eval.environment import Environment +from localstack.services.stepfunctions.asl.utils.boto_client import boto_client_for from localstack.services.stepfunctions.asl.utils.encoding import to_json_str -from localstack.services.stepfunctions.backend.utils import get_boto_client from localstack.utils.collections import select_from_typed_dict from localstack.utils.run import to_str from localstack.utils.strings import to_bytes @@ -20,8 +20,9 @@ def __init__(self, function_error: Optional[str], payload: str): self.payload = payload -def exec_lambda_function(env: Environment, parameters: dict) -> None: - lambda_client = get_boto_client(env, "lambda") +def exec_lambda_function(env: Environment, parameters: dict, region: str, account: str) -> None: + lambda_client = boto_client_for(region=region, account=account, service="lambda") + invocation_resp: InvocationResponse = lambda_client.invoke(**parameters) func_error: Optional[str] = invocation_resp.get("FunctionError") @@ -33,7 +34,6 @@ def exec_lambda_function(env: Environment, parameters: dict) -> None: resp_payload = invocation_resp["Payload"].read() resp_payload_str = to_str(resp_payload) resp_payload_json: json = json.loads(resp_payload_str) - # resp_payload_value = resp_payload_json if resp_payload_json is not None else dict() invocation_resp["Payload"] = resp_payload_json response = select_from_typed_dict(typed_dict=InvocationResponse, obj=invocation_resp) diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/resource.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/resource.py index c8cd46b96ce1e..1d4ac9abf9b47 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/resource.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/resource.py @@ -4,7 +4,8 @@ from itertools import takewhile from typing import Final, Optional -from localstack.services.stepfunctions.asl.component.component import Component +from localstack.services.stepfunctions.asl.component.eval_component import EvalComponent +from localstack.services.stepfunctions.asl.eval.environment import Environment class ResourceCondition(str): @@ -71,25 +72,33 @@ def from_arn(cls, arn: str) -> ResourceARN: ) -class Resource(Component, abc.ABC): +class Resource(EvalComponent, abc.ABC): + class ResourceOutput: + resource_arn: Final[str] + partition: Final[str] + region: Final[str] + account: Final[str] + + def __init__(self, resource_arn: str, partition: str, region: str, account: str): + self.resource_arn = resource_arn + self.partition = partition + self.region = region + self.account = account + + _region: Final[str] + _account: Final[str] resource_arn: Final[str] partition: Final[str] - region: Final[str] - account: Final[str] def __init__(self, resource_arn: ResourceARN): + self._region = resource_arn.region + self._account = resource_arn.account self.resource_arn = resource_arn.arn self.partition = resource_arn.partition - self.region = resource_arn.region - self.account = resource_arn.account @staticmethod - def from_resource_arn(account_id: str, region_name: str, arn: str) -> Resource: + def from_resource_arn(arn: str) -> Resource: resource_arn = ResourceARN.from_arn(arn) - if not resource_arn.account: - resource_arn.account = account_id - if not resource_arn.region: - resource_arn.region = region_name match resource_arn.service, resource_arn.task_type: case "lambda", "function": return LambdaResource(resource_arn=resource_arn) @@ -98,24 +107,97 @@ def from_resource_arn(account_id: str, region_name: str, arn: str) -> Resource: case "states", _: return ServiceResource(resource_arn=resource_arn) + def _build_resource(self, env: Environment) -> Resource.ResourceOutput: + region = self._region if self._region else env.aws_execution_details.region + account = self._account if self._account else env.aws_execution_details.account + return Resource.ResourceOutput( + resource_arn=self.resource_arn, + partition=self.partition, + region=region, + account=account, + ) + + def _eval_body(self, env: Environment) -> None: + resource_output = self._build_resource(env=env) + env.stack.append(resource_output) + class ActivityResource(Resource): + class ActivityResourceOutput(Resource.ResourceOutput): + name: Final[str] + + def __init__(self, resource_arn: str, partition: str, region: str, account: str, name: str): + super().__init__( + resource_arn=resource_arn, partition=partition, region=region, account=account + ) + self.name = name + name: Final[str] def __init__(self, resource_arn: ResourceARN): super().__init__(resource_arn=resource_arn) self.name = resource_arn.name + def _build_resource(self, env: Environment) -> Resource.ResourceOutput: + resource_output: Resource.ResourceOutput = super()._build_resource(env=env) + activity_resource_output = ActivityResource.ActivityResourceOutput( + **vars(resource_output), name=self.name + ) + return activity_resource_output + class LambdaResource(Resource): + class LambdaResourceOutput(Resource.ResourceOutput): + function_name: Final[str] + + def __init__( + self, resource_arn: str, partition: str, region: str, account: str, function_name: str + ): + super().__init__( + resource_arn=resource_arn, partition=partition, region=region, account=account + ) + self.function_name = function_name + function_name: Final[str] def __init__(self, resource_arn: ResourceARN): super().__init__(resource_arn=resource_arn) - self.function_name: str = resource_arn.name + self.function_name = resource_arn.name + + def _build_resource(self, env: Environment) -> Resource.ResourceOutput: + resource_output: Resource.ResourceOutput = super()._build_resource(env=env) + lambda_resource_output = LambdaResource.LambdaResourceOutput( + **vars(resource_output), function_name=self.function_name + ) + return lambda_resource_output class ServiceResource(Resource): + class ServiceResourceOutput(Resource.ResourceOutput): + service_name: Final[str] + api_name: Final[str] + api_action: Final[str] + condition: Final[Optional[str]] + + def __init__( + self, + resource_arn: str, + partition: str, + region: str, + account: str, + service_name: str, + api_name: str, + api_action: str, + condition: Optional[str], + ): + super().__init__( + resource_arn=resource_arn, partition=partition, region=region, account=account + ) + self.service_name = service_name + self.api_name = api_name + self.api_action = api_action + self.condition = condition + service_name: Final[str] api_name: Final[str] api_action: Final[str] @@ -147,3 +229,14 @@ def __init__(self, resource_arn: ResourceARN): self.condition = ResourceCondition.Sync2 case unsupported: raise RuntimeError(f"Unsupported condition '{unsupported}'.") + + def _build_resource(self, env: Environment) -> Resource.ResourceOutput: + resource_output: Resource.ResourceOutput = super()._build_resource(env=env) + lambda_resource_output = ServiceResource.ServiceResourceOutput( + **vars(resource_output), + service_name=self.service_name, + api_name=self.api_name, + api_action=self.api_action, + condition=self.condition, + ) + return lambda_resource_output diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service.py index 62f85933a29dd..f861f265aea90 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service.py @@ -54,16 +54,20 @@ def _get_timed_out_failure_event(self) -> FailureEvent: ) @abc.abstractmethod - def _eval_service_task(self, env: Environment, parameters: dict): + def _eval_service_task( + self, env: Environment, resource: ServiceResource.ServiceResourceOutput, parameters: dict + ): ... - def _before_eval_execution(self, env: Environment, parameters: dict) -> None: + def _before_eval_execution( + self, env: Environment, resource: ServiceResource.ServiceResourceOutput, parameters: dict + ) -> None: parameters_str = to_json_str(parameters) scheduled_event_details = TaskScheduledEventDetails( resource=self._get_sfn_resource(), resourceType=self._get_sfn_resource_type(), - region=self.resource.region, + region=resource.region, parameters=parameters_str, ) if not self.timeout.is_default_value(): @@ -104,10 +108,14 @@ def _after_eval_execution(self, env: Environment) -> None: def _eval_execution(self, env: Environment) -> None: parameters = self._eval_parameters(env=env) - self._before_eval_execution(env=env, parameters=parameters) + + self.resource.eval(env=env) + resource_output: ServiceResource.ServiceResourceOutput = env.stack.pop() + + self._before_eval_execution(env=env, resource=resource_output, parameters=parameters) normalised_parameters = self._normalised_parameters_bindings(parameters) - self._eval_service_task(env=env, parameters=normalised_parameters) + self._eval_service_task(env=env, resource=resource_output, parameters=normalised_parameters) self._after_eval_execution(env=env) diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_api_gateway.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_api_gateway.py index 483e58579bcff..fd1644491d96c 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_api_gateway.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_api_gateway.py @@ -24,6 +24,9 @@ from localstack.services.stepfunctions.asl.component.common.error_name.failure_event import ( FailureEvent, ) +from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( + ServiceResource, +) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service_callback import ( StateTaskServiceCallback, ) @@ -98,7 +101,7 @@ def __init__(self, parameters: TaskParameters, response: Response): class StateTaskServiceApiGateway(StateTaskServiceCallback): _SUPPORTED_API_PARAM_BINDINGS: Final[dict[str, set[str]]] = { - SupportedApiCalls.invoke: set(TaskParameters.__required_keys__) + SupportedApiCalls.invoke: set(TaskParameters.__required_keys__) # noqa } _FORBIDDEN_HTTP_HEADERS_PREFIX: Final[set[str]] = {"X-Forwarded", "X-Amz", "X-Amzn"} @@ -246,7 +249,9 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent: ), ) - def _eval_service_task(self, env: Environment, parameters: dict): + def _eval_service_task( + self, env: Environment, resource: ServiceResource.ServiceResourceOutput, parameters: dict + ): task_parameters: TaskParameters = select_from_typed_dict( typed_dict=TaskParameters, obj=parameters ) diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_aws_sdk.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_aws_sdk.py index 9c009f5a95fcb..5c01f0d21c070 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_aws_sdk.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_aws_sdk.py @@ -11,13 +11,16 @@ from localstack.services.stepfunctions.asl.component.common.error_name.states_error_name_type import ( StatesErrorNameType, ) +from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( + ServiceResource, +) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service_callback import ( StateTaskServiceCallback, ) from localstack.services.stepfunctions.asl.component.state.state_props import StateProps from localstack.services.stepfunctions.asl.eval.environment import Environment from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails -from localstack.services.stepfunctions.backend.utils import get_boto_client +from localstack.services.stepfunctions.asl.utils.boto_client import boto_client_for from localstack.utils.common import camel_to_snake_case @@ -103,8 +106,14 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent: return failure_event return super()._from_error(env=env, ex=ex) - def _eval_service_task(self, env: Environment, parameters: dict) -> None: - api_client = get_boto_client(env, self._normalised_api_name) + def _eval_service_task( + self, env: Environment, resource: ServiceResource.ServiceResourceOutput, parameters: dict + ): + api_client = boto_client_for( + region=resource.region, + account=resource.account, + service=self._normalised_api_name, + ) response = getattr(api_client, self._normalised_api_action)(**parameters) or dict() if response: response.pop("ResponseMetadata", None) diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_dynamodb.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_dynamodb.py index a2f0309a299bd..eb7fc57c45cbb 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_dynamodb.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_dynamodb.py @@ -9,12 +9,15 @@ from localstack.services.stepfunctions.asl.component.common.error_name.failure_event import ( FailureEvent, ) +from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( + ServiceResource, +) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service import ( StateTaskService, ) from localstack.services.stepfunctions.asl.eval.environment import Environment from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails -from localstack.services.stepfunctions.backend.utils import get_boto_client +from localstack.services.stepfunctions.asl.utils.boto_client import boto_client_for from localstack.utils.strings import camel_to_snake_case @@ -121,9 +124,15 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent: ), ) - def _eval_service_task(self, env: Environment, parameters: dict) -> None: - api_action = camel_to_snake_case(self.resource.api_action) - dynamodb_client = get_boto_client(env, "dynamodb") + def _eval_service_task( + self, env: Environment, resource: ServiceResource.ServiceResourceOutput, parameters: dict + ): + api_action = camel_to_snake_case(resource.api_action) + dynamodb_client = boto_client_for( + region=resource.region, + account=resource.account, + service="dynamodb", + ) response = getattr(dynamodb_client, api_action)(**parameters) response.pop("ResponseMetadata", None) env.stack.append(response) diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_events.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_events.py index aedfdc68131aa..eeec727d22b5e 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_events.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_events.py @@ -9,13 +9,16 @@ from localstack.services.stepfunctions.asl.component.common.error_name.failure_event import ( FailureEvent, ) +from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( + ServiceResource, +) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service_callback import ( StateTaskServiceCallback, ) from localstack.services.stepfunctions.asl.eval.environment import Environment from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails +from localstack.services.stepfunctions.asl.utils.boto_client import boto_client_for from localstack.services.stepfunctions.asl.utils.encoding import to_json_str -from localstack.services.stepfunctions.backend.utils import get_boto_client from localstack.utils.strings import camel_to_snake_case @@ -71,16 +74,22 @@ def _normalised_request_parameters(env: Environment, parameters: dict): resources.append(env.context_object_manager.context_object["Execution"]["Id"]) entry["Resources"] = resources - def _eval_service_task(self, env: Environment, parameters: dict) -> None: + def _eval_service_task( + self, env: Environment, resource: ServiceResource.ServiceResourceOutput, parameters: dict + ): self._normalised_request_parameters(env=env, parameters=parameters) - api_action = camel_to_snake_case(self.resource.api_action) - events_client = get_boto_client(env, "events") + api_action = camel_to_snake_case(resource.api_action) + events_client = boto_client_for( + region=resource.region, + account=resource.account, + service="events", + ) response = getattr(events_client, api_action)(**parameters) response.pop("ResponseMetadata", None) # If the response from PutEvents contains a non-zero FailedEntryCount then the # Task state fails with the error EventBridge.FailedEntry. - if self.resource.api_action == "putevents": + if resource.api_action == "putevents": failed_entry_count = response.get("FailedEntryCount", 0) if failed_entry_count > 0: # TODO: pipe events' cause in the exception object. At them moment diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_lambda.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_lambda.py index f75722d4ecdbe..81d75b99dfac8 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_lambda.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_lambda.py @@ -12,6 +12,9 @@ from localstack.services.stepfunctions.asl.component.state.state_execution.state_task import ( lambda_eval_utils, ) +from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( + ServiceResource, +) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service_callback import ( StateTaskServiceCallback, ) @@ -75,7 +78,11 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent: ), ) - def _eval_service_task(self, env: Environment, parameters: dict) -> None: + def _eval_service_task( + self, env: Environment, resource: ServiceResource.ServiceResourceOutput, parameters: dict + ): if "Payload" in parameters: parameters["Payload"] = lambda_eval_utils.to_payload_type(parameters["Payload"]) - lambda_eval_utils.exec_lambda_function(env=env, parameters=parameters) + lambda_eval_utils.exec_lambda_function( + env=env, parameters=parameters, region=resource.region, account=resource.account + ) diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sfn.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sfn.py index a139a1ac7feaf..c2c42b7a695f9 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sfn.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sfn.py @@ -22,13 +22,16 @@ from localstack.services.stepfunctions.asl.component.common.error_name.states_error_name_type import ( StatesErrorNameType, ) +from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( + ServiceResource, +) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service_callback import ( StateTaskServiceCallback, ) from localstack.services.stepfunctions.asl.eval.environment import Environment from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails +from localstack.services.stepfunctions.asl.utils.boto_client import boto_client_for from localstack.services.stepfunctions.asl.utils.encoding import to_json_str -from localstack.services.stepfunctions.backend.utils import get_boto_client from localstack.utils.collections import select_from_typed_dict from localstack.utils.strings import camel_to_snake_case @@ -159,16 +162,31 @@ def _replace_with_json_if_str(key: str) -> None: _replace_with_json_if_str("input") _replace_with_json_if_str("output") - def _eval_service_task(self, env: Environment, parameters: dict) -> None: - api_action = camel_to_snake_case(self.resource.api_action) - sfn_client = get_boto_client(env, "stepfunctions") + def _eval_service_task( + self, env: Environment, resource: ServiceResource.ServiceResourceOutput, parameters: dict + ): + api_action = camel_to_snake_case(resource.api_action) + sfn_client = boto_client_for( + region=resource.region, + account=resource.account, + service="stepfunctions", + ) response = getattr(sfn_client, api_action)(**parameters) response.pop("ResponseMetadata", None) - self._normalise_botocore_response(self.resource.api_action, response) + self._normalise_botocore_response(response.api_action, response) env.stack.append(response) - def _sync_to_start_machine(self, env: Environment, sync2_response: bool) -> None: - sfn_client = get_boto_client(env, "stepfunctions") + def _sync_to_start_machine( + self, + env: Environment, + resource: ServiceResource.ServiceResourceOutput, + sync2_response: bool, + ) -> None: + sfn_client = boto_client_for( + region=resource.region, + account=resource.account, + service="stepfunctions", + ) submission_output: dict = env.stack.pop() execution_arn: str = submission_output["ExecutionArn"] diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sns.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sns.py index 3d8d59f0c4aa5..a3514864a417d 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sns.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sns.py @@ -9,12 +9,15 @@ from localstack.services.stepfunctions.asl.component.common.error_name.failure_event import ( FailureEvent, ) +from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( + ServiceResource, +) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service_callback import ( StateTaskServiceCallback, ) from localstack.services.stepfunctions.asl.eval.environment import Environment from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails -from localstack.services.stepfunctions.backend.utils import get_boto_client +from localstack.services.stepfunctions.asl.utils.boto_client import boto_client_for from localstack.utils.strings import camel_to_snake_case @@ -69,9 +72,15 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent: ) return super()._from_error(env=env, ex=ex) - def _eval_service_task(self, env: Environment, parameters: dict) -> None: + def _eval_service_task( + self, env: Environment, resource: ServiceResource.ServiceResourceOutput, parameters: dict + ): api_action = camel_to_snake_case(self.resource.api_action) - sns_client = get_boto_client(env, "sns") + sns_client = boto_client_for( + region=resource.region, + account=resource.account, + service="sns", + ) response = getattr(sns_client, api_action)(**parameters) response.pop("ResponseMetadata", None) env.stack.append(response) diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sqs.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sqs.py index ae9dd10860c39..9510af03b9248 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sqs.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sqs.py @@ -9,13 +9,16 @@ from localstack.services.stepfunctions.asl.component.common.error_name.failure_event import ( FailureEvent, ) +from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( + ServiceResource, +) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service_callback import ( StateTaskServiceCallback, ) from localstack.services.stepfunctions.asl.eval.environment import Environment from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails +from localstack.services.stepfunctions.asl.utils.boto_client import boto_client_for from localstack.services.stepfunctions.asl.utils.encoding import to_json_str -from localstack.services.stepfunctions.backend.utils import get_boto_client from localstack.utils.strings import camel_to_snake_case @@ -55,7 +58,9 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent: ) return super()._from_error(env=env, ex=ex) - def _eval_service_task(self, env: Environment, parameters: dict) -> None: + def _eval_service_task( + self, env: Environment, resource: ServiceResource.ServiceResourceOutput, parameters: dict + ): # TODO: Stepfunctions automatically dumps to json MessageBody's definitions. # Are these other similar scenarios? if "MessageBody" in parameters: @@ -64,7 +69,11 @@ def _eval_service_task(self, env: Environment, parameters: dict) -> None: parameters["MessageBody"] = to_json_str(message_body) api_action = camel_to_snake_case(self.resource.api_action) - sqs_client = get_boto_client(env, "sqs") + sqs_client = boto_client_for( + region=resource.region, + account=resource.account, + service="sqs", + ) response = getattr(sqs_client, api_action)(**parameters) response.pop("ResponseMetadata", None) env.stack.append(response) diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/state_task_lambda.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/state_task_lambda.py index 6ff2789401626..385cec06fb474 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/state_task_lambda.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/state_task_lambda.py @@ -121,7 +121,12 @@ def _eval_execution(self, env: Environment) -> None: if "Payload" in parameters: parameters["Payload"] = lambda_eval_utils.to_payload_type(parameters["Payload"]) - lambda_eval_utils.exec_lambda_function(env=env, parameters=parameters) + self.resource.eval(env=env) + resource = env.stack.pop() + + lambda_eval_utils.exec_lambda_function( + env=env, parameters=parameters, region=resource.region, account=resource.account + ) # In lambda invocations, only payload is passed on as output. output = env.stack.pop() diff --git a/localstack/services/stepfunctions/asl/eval/aws_execution_details.py b/localstack/services/stepfunctions/asl/eval/aws_execution_details.py new file mode 100644 index 0000000000000..495d870ae2d45 --- /dev/null +++ b/localstack/services/stepfunctions/asl/eval/aws_execution_details.py @@ -0,0 +1,12 @@ +from typing import Final + + +class AWSExecutionDetails: + account: Final[str] + region: Final[str] + role_arn: Final[str] + + def __init__(self, account: str, region: str, role_arn: str): + self.account = account + self.region = region + self.role_arn = role_arn diff --git a/localstack/services/stepfunctions/asl/eval/environment.py b/localstack/services/stepfunctions/asl/eval/environment.py index fe4915223cddf..8d00f7b2bf8ba 100644 --- a/localstack/services/stepfunctions/asl/eval/environment.py +++ b/localstack/services/stepfunctions/asl/eval/environment.py @@ -3,9 +3,10 @@ import copy import logging import threading -from typing import Any, Optional +from typing import Any, Final, Optional from localstack.aws.api.stepfunctions import ExecutionFailedEventDetails, Timestamp +from localstack.services.stepfunctions.asl.eval.aws_execution_details import AWSExecutionDetails from localstack.services.stepfunctions.asl.eval.callback.callback import CallbackPoolManager from localstack.services.stepfunctions.asl.eval.contextobject.contex_object import ( ContextObject, @@ -27,13 +28,9 @@ class Environment: def __init__( - self, account_id: str, region_name: str, context_object_init: ContextObjectInitData + self, aws_execution_details: AWSExecutionDetails, context_object_init: ContextObjectInitData ): super(Environment, self).__init__() - - self.account_id = account_id - self.region_name = region_name - self._state_mutex = threading.RLock() self._program_state: Optional[ProgramState] = None self.program_state_event = threading.Event() @@ -41,6 +38,7 @@ def __init__( self.event_history: EventHistory = EventHistory() self.callback_pool_manager: CallbackPoolManager = CallbackPoolManager() + self.aws_execution_details: Final[AWSExecutionDetails] = aws_execution_details self._is_frame: bool = False self.heap: dict[str, Any] = dict() @@ -61,8 +59,7 @@ def as_frame_of(cls, env: Environment): StateMachine=env.context_object_manager.context_object["StateMachine"], ) frame = cls( - account_id=env.account_id, - region_name=env.region_name, + aws_execution_details=env.aws_execution_details, context_object_init=context_object_init, ) frame._is_frame = True diff --git a/localstack/services/stepfunctions/asl/parse/asl_parser.py b/localstack/services/stepfunctions/asl/parse/asl_parser.py index 0d21557ef5f99..260c326523c59 100644 --- a/localstack/services/stepfunctions/asl/parse/asl_parser.py +++ b/localstack/services/stepfunctions/asl/parse/asl_parser.py @@ -11,13 +11,13 @@ class AmazonStateLanguageParser(abc.ABC): @staticmethod - def parse(account_id: str, region_name: str, src: str) -> Program: + def parse(src: str) -> Program: input_stream = InputStream(src) lexer = ASLLexer(input_stream) stream = CommonTokenStream(lexer) parser = ASLParser(stream) parser._errHandler = antlr4.BailErrorStrategy() tree = parser.program_decl() - preprocessor = Preprocessor(account_id, region_name) + preprocessor = Preprocessor() program = preprocessor.visit(tree) return program diff --git a/localstack/services/stepfunctions/asl/parse/preprocessor.py b/localstack/services/stepfunctions/asl/parse/preprocessor.py index 53db7856a94eb..909c9a144ffc5 100644 --- a/localstack/services/stepfunctions/asl/parse/preprocessor.py +++ b/localstack/services/stepfunctions/asl/parse/preprocessor.py @@ -191,10 +191,6 @@ class Preprocessor(ASLParserVisitor): - def __init__(self, account_id: str, region_name: str): - self.account_id = account_id - self.region_name = region_name - @staticmethod def _inner_string_of(parse_tree: ParseTree) -> Optional[str]: if Antlr4Utils.is_terminal(parse_tree, ASLLexer.NULL): @@ -235,7 +231,7 @@ def visitState_type(self, ctx: ASLParser.State_typeContext) -> StateType: def visitResource_decl(self, ctx: ASLParser.Resource_declContext) -> Resource: inner_str = self._inner_string_of(parse_tree=ctx.keyword_or_string()) - return Resource.from_resource_arn(self.account_id, self.region_name, inner_str) + return Resource.from_resource_arn(inner_str) def visitEnd_decl(self, ctx: ASLParser.End_declContext) -> End: bool_child: ParseTree = ctx.children[-1] diff --git a/localstack/services/stepfunctions/asl/utils/boto_client.py b/localstack/services/stepfunctions/asl/utils/boto_client.py new file mode 100644 index 0000000000000..4f861e6c150a3 --- /dev/null +++ b/localstack/services/stepfunctions/asl/utils/boto_client.py @@ -0,0 +1,13 @@ +from botocore.client import BaseClient +from botocore.config import Config + +from localstack.aws.connect import connect_to + + +def boto_client_for(region: str, account: str, service: str) -> BaseClient: + return connect_to.get_client( + aws_access_key_id=account, + region_name=region, + service_name=service, + config=Config(parameter_validation=False), + ) diff --git a/localstack/services/stepfunctions/backend/execution.py b/localstack/services/stepfunctions/backend/execution.py index f667a94eaece2..dab492dc3c86b 100644 --- a/localstack/services/stepfunctions/backend/execution.py +++ b/localstack/services/stepfunctions/backend/execution.py @@ -23,6 +23,7 @@ TraceHeader, ) from localstack.aws.connect import connect_to +from localstack.services.stepfunctions.asl.eval.aws_execution_details import AWSExecutionDetails from localstack.services.stepfunctions.asl.eval.contextobject.contex_object import ( ContextObjectInitData, ) @@ -210,9 +211,6 @@ def start(self) -> None: raise InvalidName() # TODO. self.exec_worker = ExecutionWorker( - account_id=self.account_id, - region_name=self.region_name, - role_arn=self.role_arn, definition=self.state_machine.definition, input_data=self.input_data, exec_comm=Execution.BaseExecutionWorkerComm(self), @@ -229,6 +227,9 @@ def start(self) -> None: Name=self.state_machine.name, ), ), + aws_execution_details=AWSExecutionDetails( + account=self.account_id, region=self.region_name, role_arn=self.role_arn + ), ) self.exec_status = ExecutionStatus.RUNNING self._publish_execution_status_change_event() diff --git a/localstack/services/stepfunctions/backend/execution_worker.py b/localstack/services/stepfunctions/backend/execution_worker.py index 0569aa6132a7d..fd810e32e1f00 100644 --- a/localstack/services/stepfunctions/backend/execution_worker.py +++ b/localstack/services/stepfunctions/backend/execution_worker.py @@ -4,13 +4,13 @@ from typing import Final, Optional from localstack.aws.api.stepfunctions import ( - Arn, Definition, ExecutionStartedEventDetails, HistoryEventExecutionDataDetails, HistoryEventType, ) from localstack.services.stepfunctions.asl.component.program.program import Program +from localstack.services.stepfunctions.asl.eval.aws_execution_details import AWSExecutionDetails from localstack.services.stepfunctions.asl.eval.contextobject.contex_object import ( ContextObjectInitData, ) @@ -22,43 +22,36 @@ class ExecutionWorker: - role_arn: Final[Arn] - definition: Definition - input_data: Optional[dict] env: Optional[Environment] - _context_object_init: ContextObjectInitData - exec_comm: Final[ExecutionWorkerComm] + _definition: Definition + _input_data: Optional[dict] + _exec_comm: Final[ExecutionWorkerComm] + _context_object_init: Final[ContextObjectInitData] + _aws_execution_details: Final[AWSExecutionDetails] def __init__( self, - account_id: str, - region_name: str, - role_arn: Arn, definition: Definition, input_data: Optional[dict], context_object_init: ContextObjectInitData, + aws_execution_details: AWSExecutionDetails, exec_comm: ExecutionWorkerComm, ): - self.account_id = account_id - self.region_name = region_name - self.role_arn = role_arn - self.definition = definition - self.input_data = input_data - self.env = None + self._definition = definition + self._input_data = input_data + self._exec_comm = exec_comm self._context_object_init = context_object_init - self.exec_comm = exec_comm + self._aws_execution_details = aws_execution_details + self.env = None def _execution_logic(self): - program: Program = AmazonStateLanguageParser.parse( - self.account_id, self.region_name, self.definition - ) + program: Program = AmazonStateLanguageParser.parse(self._definition) self.env = Environment( - account_id=self.account_id, - region_name=self.region_name, + aws_execution_details=self._aws_execution_details, context_object_init=self._context_object_init, ) self.env.inp = copy.deepcopy( - self.input_data + self._input_data ) # The program will mutate the input_data, which is otherwise constant in regard to the execution value. self.env.event_history.add_event( @@ -69,14 +62,14 @@ def _execution_logic(self): inputDetails=HistoryEventExecutionDataDetails( truncated=False ), # Always False for api calls. - roleArn=self.role_arn, + roleArn=self._aws_execution_details.role_arn, ) ), ) program.eval(self.env) - self.exec_comm.terminated() + self._exec_comm.terminated() def start(self): Thread(target=self._execution_logic).start() diff --git a/localstack/services/stepfunctions/backend/utils.py b/localstack/services/stepfunctions/backend/utils.py deleted file mode 100644 index d89817f8c71a6..0000000000000 --- a/localstack/services/stepfunctions/backend/utils.py +++ /dev/null @@ -1,13 +0,0 @@ -from botocore.config import Config - -from localstack.aws.connect import connect_to -from localstack.services.stepfunctions.asl.eval.environment import Environment - - -def get_boto_client(env: Environment, service: str): - return connect_to.get_client( - aws_access_key_id=env.account_id, - region_name=env.region_name, - service_name=service, - config=Config(parameter_validation=False), - ) diff --git a/localstack/services/stepfunctions/provider_v2.py b/localstack/services/stepfunctions/provider_v2.py index e649db89930cd..b7a2517b7cabd 100644 --- a/localstack/services/stepfunctions/provider_v2.py +++ b/localstack/services/stepfunctions/provider_v2.py @@ -119,9 +119,7 @@ def _idempotent_revision( state_machines: list[StateMachineInstance] = list( self.get_store(context).state_machines.values() ) - revisions = filter( - lambda state_machine: isinstance(state_machine, StateMachineRevision), state_machines - ) + revisions = filter(lambda sm: isinstance(sm, StateMachineRevision), state_machines) for state_machine in revisions: check = all( [ @@ -148,11 +146,11 @@ def _revision_by_name( return None @staticmethod - def _validate_definition(account_id: str, region_name: str, definition: str): + def _validate_definition(definition: str): # Validate # TODO: pass through static analyser. try: - AmazonStateLanguageParser.parse(account_id, region_name, definition) + AmazonStateLanguageParser.parse(definition) except Exception as ex: # TODO: add message from static analyser, this just helps the user debug issues in the derivation. raise InvalidDefinition(f"Error '{str(ex)}' in definition '{definition}'.") @@ -182,11 +180,7 @@ def create_state_machine( ) state_machine_definition: str = request["definition"] - StepFunctionsProvider._validate_definition( - account_id=context.account_id, - region_name=context.region, - definition=state_machine_definition, - ) + StepFunctionsProvider._validate_definition(definition=state_machine_definition) name: Optional[Name] = request["name"] arn = aws_stack_state_machine_arn( @@ -502,9 +496,7 @@ def update_state_machine( ) if definition is not None: - self._validate_definition( - account_id=context.account_id, region_name=context.region, definition=definition - ) + self._validate_definition(definition=definition) revision_id = state_machine.create_revision(definition=definition, role_arn=role_arn) diff --git a/localstack/services/stores.py b/localstack/services/stores.py index 2263b25fa5ba3..0b6c3da5ec46f 100644 --- a/localstack/services/stores.py +++ b/localstack/services/stores.py @@ -238,7 +238,7 @@ def __getitem__(self, region_name) -> BaseStoreType: store_obj._global = self._global store_obj._universal = self._universal - store_obj._service_name = self.service_name + store_obj.service_name = self.service_name store_obj._account_id = self.account_id store_obj._region_name = region_name From dfc56cce260a005323a00c0f9c5dd067f879494a Mon Sep 17 00:00:00 2001 From: MEPalma Date: Mon, 2 Oct 2023 13:25:50 +0200 Subject: [PATCH 13/14] split resource into static and runtime parts, minors --- .../common/error_name/failure_event.py | 4 - .../state_task/service/resource.py | 102 +++--------------- .../state_task/service/state_task_service.py | 43 +++++--- .../service/state_task_service_api_gateway.py | 9 +- .../service/state_task_service_aws_sdk.py | 15 ++- .../service/state_task_service_callback.py | 53 +++++++-- .../service/state_task_service_dynamodb.py | 15 +-- .../service/state_task_service_events.py | 19 ++-- .../service/state_task_service_lambda.py | 18 +++- .../service/state_task_service_sfn.py | 51 ++++++--- .../service/state_task_service_sns.py | 13 ++- .../service/state_task_service_sqs.py | 19 ++-- .../state_execution/state_task/state_task.py | 4 +- .../state_task/state_task_lambda.py | 8 +- 14 files changed, 201 insertions(+), 172 deletions(-) diff --git a/localstack/services/stepfunctions/asl/component/common/error_name/failure_event.py b/localstack/services/stepfunctions/asl/component/common/error_name/failure_event.py index 839eaba4d6318..94fc2eb9bb6de 100644 --- a/localstack/services/stepfunctions/asl/component/common/error_name/failure_event.py +++ b/localstack/services/stepfunctions/asl/component/common/error_name/failure_event.py @@ -3,7 +3,6 @@ from localstack.aws.api.stepfunctions import ExecutionFailedEventDetails, HistoryEventType from localstack.services.stepfunctions.asl.component.common.error_name.error_name import ErrorName from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails -from localstack.services.stepfunctions.asl.utils.encoding import to_json_str class FailureEvent: @@ -28,9 +27,6 @@ class FailureEventException(Exception): def __init__(self, failure_event: FailureEvent): self.failure_event = failure_event - def __str__(self) -> str: - return to_json_str(self.failure_event.event_details) - def get_execution_failed_event_details(self) -> Optional[ExecutionFailedEventDetails]: if self.failure_event.event_details is None: return None diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/resource.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/resource.py index 1d4ac9abf9b47..3405562439913 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/resource.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/resource.py @@ -72,19 +72,16 @@ def from_arn(cls, arn: str) -> ResourceARN: ) -class Resource(EvalComponent, abc.ABC): - class ResourceOutput: - resource_arn: Final[str] - partition: Final[str] - region: Final[str] - account: Final[str] - - def __init__(self, resource_arn: str, partition: str, region: str, account: str): - self.resource_arn = resource_arn - self.partition = partition - self.region = region - self.account = account +class ResourceRuntimePart: + account: Final[str] + region: Final[str] + + def __init__(self, account: str, region: str): + self.region = region + self.account = account + +class Resource(EvalComponent, abc.ABC): _region: Final[str] _account: Final[str] resource_arn: Final[str] @@ -107,56 +104,28 @@ def from_resource_arn(arn: str) -> Resource: case "states", _: return ServiceResource(resource_arn=resource_arn) - def _build_resource(self, env: Environment) -> Resource.ResourceOutput: + def _eval_runtime_part(self, env: Environment) -> ResourceRuntimePart: region = self._region if self._region else env.aws_execution_details.region account = self._account if self._account else env.aws_execution_details.account - return Resource.ResourceOutput( - resource_arn=self.resource_arn, - partition=self.partition, - region=region, + return ResourceRuntimePart( account=account, + region=region, ) def _eval_body(self, env: Environment) -> None: - resource_output = self._build_resource(env=env) - env.stack.append(resource_output) + runtime_part = self._eval_runtime_part(env=env) + env.stack.append(runtime_part) class ActivityResource(Resource): - class ActivityResourceOutput(Resource.ResourceOutput): - name: Final[str] - - def __init__(self, resource_arn: str, partition: str, region: str, account: str, name: str): - super().__init__( - resource_arn=resource_arn, partition=partition, region=region, account=account - ) - self.name = name - name: Final[str] def __init__(self, resource_arn: ResourceARN): super().__init__(resource_arn=resource_arn) self.name = resource_arn.name - def _build_resource(self, env: Environment) -> Resource.ResourceOutput: - resource_output: Resource.ResourceOutput = super()._build_resource(env=env) - activity_resource_output = ActivityResource.ActivityResourceOutput( - **vars(resource_output), name=self.name - ) - return activity_resource_output - class LambdaResource(Resource): - class LambdaResourceOutput(Resource.ResourceOutput): - function_name: Final[str] - - def __init__( - self, resource_arn: str, partition: str, region: str, account: str, function_name: str - ): - super().__init__( - resource_arn=resource_arn, partition=partition, region=region, account=account - ) - self.function_name = function_name function_name: Final[str] @@ -164,40 +133,8 @@ def __init__(self, resource_arn: ResourceARN): super().__init__(resource_arn=resource_arn) self.function_name = resource_arn.name - def _build_resource(self, env: Environment) -> Resource.ResourceOutput: - resource_output: Resource.ResourceOutput = super()._build_resource(env=env) - lambda_resource_output = LambdaResource.LambdaResourceOutput( - **vars(resource_output), function_name=self.function_name - ) - return lambda_resource_output - class ServiceResource(Resource): - class ServiceResourceOutput(Resource.ResourceOutput): - service_name: Final[str] - api_name: Final[str] - api_action: Final[str] - condition: Final[Optional[str]] - - def __init__( - self, - resource_arn: str, - partition: str, - region: str, - account: str, - service_name: str, - api_name: str, - api_action: str, - condition: Optional[str], - ): - super().__init__( - resource_arn=resource_arn, partition=partition, region=region, account=account - ) - self.service_name = service_name - self.api_name = api_name - self.api_action = api_action - self.condition = condition - service_name: Final[str] api_name: Final[str] api_action: Final[str] @@ -229,14 +166,3 @@ def __init__(self, resource_arn: ResourceARN): self.condition = ResourceCondition.Sync2 case unsupported: raise RuntimeError(f"Unsupported condition '{unsupported}'.") - - def _build_resource(self, env: Environment) -> Resource.ResourceOutput: - resource_output: Resource.ResourceOutput = super()._build_resource(env=env) - lambda_resource_output = ServiceResource.ServiceResourceOutput( - **vars(resource_output), - service_name=self.service_name, - api_name=self.api_name, - api_action=self.api_action, - condition=self.condition, - ) - return lambda_resource_output diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service.py index f861f265aea90..5429799749f96 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service.py @@ -20,6 +20,7 @@ StatesErrorNameType, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( + ResourceRuntimePart, ServiceResource, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.state_task import ( @@ -55,19 +56,22 @@ def _get_timed_out_failure_event(self) -> FailureEvent: @abc.abstractmethod def _eval_service_task( - self, env: Environment, resource: ServiceResource.ServiceResourceOutput, parameters: dict + self, + env: Environment, + resource_runtime_part: ResourceRuntimePart, + normalised_parameters: dict, ): ... def _before_eval_execution( - self, env: Environment, resource: ServiceResource.ServiceResourceOutput, parameters: dict + self, env: Environment, resource_runtime_part: ResourceRuntimePart, raw_parameters: dict ) -> None: - parameters_str = to_json_str(parameters) + parameters_str = to_json_str(raw_parameters) scheduled_event_details = TaskScheduledEventDetails( resource=self._get_sfn_resource(), resourceType=self._get_sfn_resource_type(), - region=resource.region, + region=resource_runtime_part.region, parameters=parameters_str, ) if not self.timeout.is_default_value(): @@ -92,7 +96,12 @@ def _before_eval_execution( ), ) - def _after_eval_execution(self, env: Environment) -> None: + def _after_eval_execution( + self, + env: Environment, + resource_runtime_part: ResourceRuntimePart, + normalised_parameters: dict, + ) -> None: output = env.stack[-1] env.event_history.add_event( hist_type_event=HistoryEventType.TaskSucceeded, @@ -107,17 +116,27 @@ def _after_eval_execution(self, env: Environment) -> None: ) def _eval_execution(self, env: Environment) -> None: - parameters = self._eval_parameters(env=env) - self.resource.eval(env=env) - resource_output: ServiceResource.ServiceResourceOutput = env.stack.pop() + resource_runtime_part: ResourceRuntimePart = env.stack.pop() + + raw_parameters = self._eval_parameters(env=env) - self._before_eval_execution(env=env, resource=resource_output, parameters=parameters) + self._before_eval_execution( + env=env, resource_runtime_part=resource_runtime_part, raw_parameters=raw_parameters + ) - normalised_parameters = self._normalised_parameters_bindings(parameters) - self._eval_service_task(env=env, resource=resource_output, parameters=normalised_parameters) + normalised_parameters = self._normalised_parameters_bindings(raw_parameters) + self._eval_service_task( + env=env, + resource_runtime_part=resource_runtime_part, + normalised_parameters=normalised_parameters, + ) - self._after_eval_execution(env=env) + self._after_eval_execution( + env=env, + resource_runtime_part=resource_runtime_part, + normalised_parameters=normalised_parameters, + ) @classmethod def for_service(cls, service_name: str) -> StateTaskService: diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_api_gateway.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_api_gateway.py index fd1644491d96c..5eae9960a2268 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_api_gateway.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_api_gateway.py @@ -25,7 +25,7 @@ FailureEvent, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( - ServiceResource, + ResourceRuntimePart, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service_callback import ( StateTaskServiceCallback, @@ -250,10 +250,13 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent: ) def _eval_service_task( - self, env: Environment, resource: ServiceResource.ServiceResourceOutput, parameters: dict + self, + env: Environment, + resource_runtime_part: ResourceRuntimePart, + normalised_parameters: dict, ): task_parameters: TaskParameters = select_from_typed_dict( - typed_dict=TaskParameters, obj=parameters + typed_dict=TaskParameters, obj=normalised_parameters ) method = task_parameters["Method"] diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_aws_sdk.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_aws_sdk.py index 5c01f0d21c070..424baf1a1aa6b 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_aws_sdk.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_aws_sdk.py @@ -12,7 +12,7 @@ StatesErrorNameType, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( - ServiceResource, + ResourceRuntimePart, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service_callback import ( StateTaskServiceCallback, @@ -107,14 +107,19 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent: return super()._from_error(env=env, ex=ex) def _eval_service_task( - self, env: Environment, resource: ServiceResource.ServiceResourceOutput, parameters: dict + self, + env: Environment, + resource_runtime_part: ResourceRuntimePart, + normalised_parameters: dict, ): api_client = boto_client_for( - region=resource.region, - account=resource.account, + region=resource_runtime_part.region, + account=resource_runtime_part.account, service=self._normalised_api_name, ) - response = getattr(api_client, self._normalised_api_action)(**parameters) or dict() + response = ( + getattr(api_client, self._normalised_api_action)(**normalised_parameters) or dict() + ) if response: response.pop("ResponseMetadata", None) env.stack.append(response) diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_callback.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_callback.py index 4668f906c632d..5a776dc656dc2 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_callback.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_callback.py @@ -15,6 +15,7 @@ ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( ResourceCondition, + ResourceRuntimePart, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service import ( StateTaskService, @@ -39,7 +40,12 @@ def _get_sfn_resource(self) -> str: resource += f".{self.resource.condition}" return resource - def _wait_for_task_token(self, env: Environment) -> None: + def _wait_for_task_token( + self, + env: Environment, + resource_runtime_part: ResourceRuntimePart, + normalised_parameters: dict, + ) -> None: callback_id = env.context_object_manager.context_object["Task"]["Token"] callback_endpoint = env.callback_pool_manager.get(callback_id) @@ -79,12 +85,22 @@ def _wait_for_task_token(self, env: Environment) -> None: else: raise NotImplementedError(f"Unsupported CallbackOutcome type '{type(outcome)}'.") - def _sync(self, env: Environment) -> None: + def _sync( + self, + env: Environment, + resource_runtime_part: ResourceRuntimePart, + normalised_parameters: dict, + ) -> None: raise RuntimeError( f"Unsupported .sync callback procedure in resource {self.resource.resource_arn}" ) - def _sync2(self, env: Environment) -> None: + def _sync2( + self, + env: Environment, + resource_runtime_part: ResourceRuntimePart, + normalised_parameters: dict, + ) -> None: raise RuntimeError( f"Unsupported .sync:2 callback procedure in resource {self.resource.resource_arn}" ) @@ -113,7 +129,12 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent: return self._get_callback_outcome_failure_event(ex=ex) return super()._from_error(env=env, ex=ex) - def _after_eval_execution(self, env: Environment) -> None: + def _after_eval_execution( + self, + env: Environment, + resource_runtime_part: ResourceRuntimePart, + normalised_parameters: dict, + ) -> None: if self._is_condition(): output = env.stack[-1] env.event_history.add_event( @@ -129,12 +150,28 @@ def _after_eval_execution(self, env: Environment) -> None: ) match self.resource.condition: case ResourceCondition.WaitForTaskToken: - self._wait_for_task_token(env=env) + self._wait_for_task_token( + env=env, + resource_runtime_part=resource_runtime_part, + normalised_parameters=normalised_parameters, + ) case ResourceCondition.Sync: - self._sync(env=env) + self._sync( + env=env, + resource_runtime_part=resource_runtime_part, + normalised_parameters=normalised_parameters, + ) case ResourceCondition.Sync2: - self._sync2(env=env) + self._sync2( + env=env, + resource_runtime_part=resource_runtime_part, + normalised_parameters=normalised_parameters, + ) case unsupported: raise NotImplementedError(f"Unsupported callback type '{unsupported}'.") - super()._after_eval_execution(env=env) + super()._after_eval_execution( + env=env, + resource_runtime_part=resource_runtime_part, + normalised_parameters=normalised_parameters, + ) diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_dynamodb.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_dynamodb.py index eb7fc57c45cbb..615d486e663f3 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_dynamodb.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_dynamodb.py @@ -10,7 +10,7 @@ FailureEvent, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( - ServiceResource, + ResourceRuntimePart, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service import ( StateTaskService, @@ -125,14 +125,17 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent: ) def _eval_service_task( - self, env: Environment, resource: ServiceResource.ServiceResourceOutput, parameters: dict + self, + env: Environment, + resource_runtime_part: ResourceRuntimePart, + normalised_parameters: dict, ): - api_action = camel_to_snake_case(resource.api_action) + api_action = camel_to_snake_case(self.resource.api_action) dynamodb_client = boto_client_for( - region=resource.region, - account=resource.account, + region=resource_runtime_part.region, + account=resource_runtime_part.account, service="dynamodb", ) - response = getattr(dynamodb_client, api_action)(**parameters) + response = getattr(dynamodb_client, api_action)(**normalised_parameters) response.pop("ResponseMetadata", None) env.stack.append(response) diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_events.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_events.py index eeec727d22b5e..5d64491ede6e5 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_events.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_events.py @@ -10,7 +10,7 @@ FailureEvent, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( - ServiceResource, + ResourceRuntimePart, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service_callback import ( StateTaskServiceCallback, @@ -75,21 +75,24 @@ def _normalised_request_parameters(env: Environment, parameters: dict): entry["Resources"] = resources def _eval_service_task( - self, env: Environment, resource: ServiceResource.ServiceResourceOutput, parameters: dict + self, + env: Environment, + resource_runtime_part: ResourceRuntimePart, + normalised_parameters: dict, ): - self._normalised_request_parameters(env=env, parameters=parameters) - api_action = camel_to_snake_case(resource.api_action) + self._normalised_request_parameters(env=env, parameters=normalised_parameters) + api_action = camel_to_snake_case(self.resource.api_action) events_client = boto_client_for( - region=resource.region, - account=resource.account, + region=resource_runtime_part.region, + account=resource_runtime_part.account, service="events", ) - response = getattr(events_client, api_action)(**parameters) + response = getattr(events_client, api_action)(**normalised_parameters) response.pop("ResponseMetadata", None) # If the response from PutEvents contains a non-zero FailedEntryCount then the # Task state fails with the error EventBridge.FailedEntry. - if resource.api_action == "putevents": + if self.resource.api_action == "putevents": failed_entry_count = response.get("FailedEntryCount", 0) if failed_entry_count > 0: # TODO: pipe events' cause in the exception object. At them moment diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_lambda.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_lambda.py index 81d75b99dfac8..82bcee79846e0 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_lambda.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_lambda.py @@ -13,7 +13,7 @@ lambda_eval_utils, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( - ServiceResource, + ResourceRuntimePart, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service_callback import ( StateTaskServiceCallback, @@ -79,10 +79,18 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent: ) def _eval_service_task( - self, env: Environment, resource: ServiceResource.ServiceResourceOutput, parameters: dict + self, + env: Environment, + resource_runtime_part: ResourceRuntimePart, + normalised_parameters: dict, ): - if "Payload" in parameters: - parameters["Payload"] = lambda_eval_utils.to_payload_type(parameters["Payload"]) + if "Payload" in normalised_parameters: + normalised_parameters["Payload"] = lambda_eval_utils.to_payload_type( + normalised_parameters["Payload"] + ) lambda_eval_utils.exec_lambda_function( - env=env, parameters=parameters, region=resource.region, account=resource.account + env=env, + parameters=normalised_parameters, + region=resource_runtime_part.region, + account=resource_runtime_part.account, ) diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sfn.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sfn.py index c2c42b7a695f9..51e7a91d0dc28 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sfn.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sfn.py @@ -23,7 +23,7 @@ StatesErrorNameType, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( - ServiceResource, + ResourceRuntimePart, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service_callback import ( StateTaskServiceCallback, @@ -134,8 +134,10 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent: ) return super()._from_error(env=env, ex=ex) - def _normalised_parameters_bindings(self, parameters: dict[str, str]) -> dict[str, str]: - normalised_parameters = super()._normalised_parameters_bindings(parameters=parameters) + def _normalised_parameters_bindings(self, raw_parameters: dict[str, str]) -> dict[str, str]: + normalised_parameters = super()._normalised_parameters_bindings( + raw_parameters=raw_parameters + ) if self.resource.api_action.lower() == "startexecution": optional_input = normalised_parameters.get("input") @@ -163,28 +165,31 @@ def _replace_with_json_if_str(key: str) -> None: _replace_with_json_if_str("output") def _eval_service_task( - self, env: Environment, resource: ServiceResource.ServiceResourceOutput, parameters: dict + self, + env: Environment, + resource_runtime_part: ResourceRuntimePart, + normalised_parameters: dict, ): - api_action = camel_to_snake_case(resource.api_action) + api_action = camel_to_snake_case(self.resource.api_action) sfn_client = boto_client_for( - region=resource.region, - account=resource.account, + region=resource_runtime_part.region, + account=resource_runtime_part.account, service="stepfunctions", ) - response = getattr(sfn_client, api_action)(**parameters) + response = getattr(sfn_client, api_action)(**normalised_parameters) response.pop("ResponseMetadata", None) - self._normalise_botocore_response(response.api_action, response) + self._normalise_botocore_response(self.resource.api_action, response) env.stack.append(response) def _sync_to_start_machine( self, env: Environment, - resource: ServiceResource.ServiceResourceOutput, + resource_runtime_part: ResourceRuntimePart, sync2_response: bool, ) -> None: sfn_client = boto_client_for( - region=resource.region, - account=resource.account, + region=resource_runtime_part.region, + account=resource_runtime_part.account, service="stepfunctions", ) submission_output: dict = env.stack.pop() @@ -228,16 +233,30 @@ def _has_terminated() -> Optional[dict]: env.stack.append(termination_output) - def _sync(self, env: Environment) -> None: + def _sync( + self, + env: Environment, + resource_runtime_part: ResourceRuntimePart, + normalised_parameters: dict, + ) -> None: match self.resource.api_action.lower(): case "startexecution": - self._sync_to_start_machine(env=env, sync2_response=False) + self._sync_to_start_machine( + env=env, resource_runtime_part=resource_runtime_part, sync2_response=False + ) case _: super()._sync(env=env) - def _sync2(self, env: Environment) -> None: + def _sync2( + self, + env: Environment, + resource_runtime_part: ResourceRuntimePart, + normalised_parameters: dict, + ) -> None: match self.resource.api_action.lower(): case "startexecution": - self._sync_to_start_machine(env=env, sync2_response=True) + self._sync_to_start_machine( + env=env, resource_runtime_part=resource_runtime_part, sync2_response=True + ) case _: super()._sync2(env=env) diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sns.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sns.py index a3514864a417d..cb8b27eeda801 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sns.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sns.py @@ -10,7 +10,7 @@ FailureEvent, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( - ServiceResource, + ResourceRuntimePart, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service_callback import ( StateTaskServiceCallback, @@ -73,14 +73,17 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent: return super()._from_error(env=env, ex=ex) def _eval_service_task( - self, env: Environment, resource: ServiceResource.ServiceResourceOutput, parameters: dict + self, + env: Environment, + resource_runtime_part: ResourceRuntimePart, + normalised_parameters: dict, ): api_action = camel_to_snake_case(self.resource.api_action) sns_client = boto_client_for( - region=resource.region, - account=resource.account, + region=resource_runtime_part.region, + account=resource_runtime_part.account, service="sns", ) - response = getattr(sns_client, api_action)(**parameters) + response = getattr(sns_client, api_action)(**normalised_parameters) response.pop("ResponseMetadata", None) env.stack.append(response) diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sqs.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sqs.py index 9510af03b9248..01593d2e67b61 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sqs.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sqs.py @@ -10,7 +10,7 @@ FailureEvent, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( - ServiceResource, + ResourceRuntimePart, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service_callback import ( StateTaskServiceCallback, @@ -59,21 +59,24 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent: return super()._from_error(env=env, ex=ex) def _eval_service_task( - self, env: Environment, resource: ServiceResource.ServiceResourceOutput, parameters: dict + self, + env: Environment, + resource_runtime_part: ResourceRuntimePart, + normalised_parameters: dict, ): # TODO: Stepfunctions automatically dumps to json MessageBody's definitions. # Are these other similar scenarios? - if "MessageBody" in parameters: - message_body = parameters["MessageBody"] + if "MessageBody" in normalised_parameters: + message_body = normalised_parameters["MessageBody"] if message_body is not None and not isinstance(message_body, str): - parameters["MessageBody"] = to_json_str(message_body) + normalised_parameters["MessageBody"] = to_json_str(message_body) api_action = camel_to_snake_case(self.resource.api_action) sqs_client = boto_client_for( - region=resource.region, - account=resource.account, + region=resource_runtime_part.region, + account=resource_runtime_part.account, service="sqs", ) - response = getattr(sqs_client, api_action)(**parameters) + response = getattr(sqs_client, api_action)(**normalised_parameters) response.pop("ResponseMetadata", None) env.stack.append(response) diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/state_task.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/state_task.py index e307dd8256119..3d30d462c9c14 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/state_task.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/state_task.py @@ -50,8 +50,8 @@ def _get_supported_parameters(self) -> Optional[set[str]]: # noqa def _get_parameters_normalising_bindings(self) -> dict[str, str]: # noqa return dict() - def _normalised_parameters_bindings(self, parameters: dict[str, str]) -> dict[str, str]: - normalised_parameters = copy.deepcopy(parameters) + def _normalised_parameters_bindings(self, raw_parameters: dict[str, str]) -> dict[str, str]: + normalised_parameters = copy.deepcopy(raw_parameters) # Normalise bindings. parameter_normalisers = self._get_parameters_normalising_bindings() for parameter_key in list(normalised_parameters.keys()): diff --git a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/state_task_lambda.py b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/state_task_lambda.py index 385cec06fb474..6d6305300e83a 100644 --- a/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/state_task_lambda.py +++ b/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/state_task_lambda.py @@ -28,6 +28,7 @@ ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( LambdaResource, + ResourceRuntimePart, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.state_task import ( StateTask, @@ -122,10 +123,13 @@ def _eval_execution(self, env: Environment) -> None: parameters["Payload"] = lambda_eval_utils.to_payload_type(parameters["Payload"]) self.resource.eval(env=env) - resource = env.stack.pop() + resource_runtime_part: ResourceRuntimePart = env.stack.pop() lambda_eval_utils.exec_lambda_function( - env=env, parameters=parameters, region=resource.region, account=resource.account + env=env, + parameters=parameters, + region=resource_runtime_part.region, + account=resource_runtime_part.account, ) # In lambda invocations, only payload is passed on as output. From 670df14bd5041f23ff2145267eb6e6bd137e4fe4 Mon Sep 17 00:00:00 2001 From: MEPalma Date: Fri, 6 Oct 2023 14:14:38 +0200 Subject: [PATCH 14/14] revert non default account and region --- localstack/constants.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/localstack/constants.py b/localstack/constants.py index 547da354d4cf7..8abca4bf08009 100644 --- a/localstack/constants.py +++ b/localstack/constants.py @@ -153,10 +153,10 @@ # Credentials used in the test suite # These can be overridden if the tests are being run against AWS # If a structured access key ID is used, it must correspond to the account ID -TEST_AWS_ACCOUNT_ID = os.getenv("TEST_AWS_ACCOUNT_ID") or "000000000001" -TEST_AWS_ACCESS_KEY_ID = os.getenv("TEST_AWS_ACCESS_KEY_ID") or "000000000001" +TEST_AWS_ACCOUNT_ID = os.getenv("TEST_AWS_ACCOUNT_ID") or DEFAULT_AWS_ACCOUNT_ID +TEST_AWS_ACCESS_KEY_ID = os.getenv("TEST_AWS_ACCESS_KEY_ID") or "test" TEST_AWS_SECRET_ACCESS_KEY = os.getenv("TEST_AWS_SECRET_ACCESS_KEY") or "test" -TEST_AWS_REGION_NAME = os.getenv("TEST_AWS_REGION") or "us-west-1" +TEST_AWS_REGION_NAME = os.getenv("TEST_AWS_REGION") or "us-east-1" # Additional credentials used in the test suite (when running cross-account tests) SECONDARY_TEST_AWS_ACCOUNT_ID = os.getenv("SECONDARY_TEST_AWS_ACCOUNT_ID") or "000000000002"