8000 StepFunctions: Multi-accounts compatibility (#9119) · codeperl/localstack@e7ddddd · GitHub
[go: up one dir, main page]

Skip to content

Commit e7ddddd

Browse files
StepFunctions: Multi-accounts compatibility (localstack#9119)
1 parent 0d27441 commit e7ddddd

27 files changed

+743
-402
lines changed

localstack/services/stepfunctions/asl/component/state/state_execution/state_map/item_reader/resource_eval/resource_eval_s3.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,39 +2,42 @@
22

33
from typing import Callable, Final
44

5-
from botocore.config import Config
6-
7-
from localstack.aws.connect import connect_to
85
from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.item_reader.resource_eval.resource_eval import (
96
ResourceEval,
107
)
8+
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
9+
ResourceRuntimePart,
10+
)
1111
from localstack.services.stepfunctions.asl.eval.environment import Environment
12+
from localstack.services.stepfunctions.asl.utils.boto_client import boto_client_for
1213
from localstack.utils.strings import camel_to_snake_case, to_str
1314

1415

1516
class ResourceEvalS3(ResourceEval):
1617
_HANDLER_REFLECTION_PREFIX: Final[str] = "_handle_"
17-
_API_ACTION_HANDLER_TYPE = Callable[[Environment], None]
18+
_API_ACTION_HANDLER_TYPE = Callable[[Environment, ResourceRuntimePart], None]
1819

1920
@staticmethod
20-
def _get_s3_client():
21-
# TODO: adjust following multi-account and invocation region changes.
22-
return connect_to.get_client(
23-
service_name="s3",
24-
config=Config(parameter_validation=False),
21+
def _get_s3_client(resource_runtime_part: ResourceRuntimePart):
22+
return boto_client_for(
23+
region=resource_runtime_part.region,
24+
account=resource_runtime_part.account,
25+
service="s3",
2526
)
2627

2728
@staticmethod
28-
def _handle_get_object(env: Environment) -> None:
29-
s3_client = ResourceEvalS3._get_s3_client()
29+
def _handle_get_object(env: Environment, resource_runtime_part: ResourceRuntimePart) -> None:
30+
s3_client = ResourceEvalS3._get_s3_client(resource_runtime_part=resource_runtime_part)
3031
parameters = env.stack.pop()
3132
response = s3_client.get_object(**parameters)
3233
content = to_str(response["Body"].read())
3334
env.stack.append(content)
3435

3536
@staticmethod
36-
def _handle_list_objects_v2(env: Environment) -> None:
37-
s3_client = ResourceEvalS3._get_s3_client()
37+
def _handle_list_objects_v2(
38+
env: Environment, resource_runtime_part: ResourceRuntimePart
39+
) -> None:
40+
s3_client = ResourceEvalS3._get_s3_client(resource_runtime_part=resource_runtime_part)
3841
parameters = env.stack.pop()
3942
response = s3_client.list_objects_v2(**parameters)
4043
contents = response["Contents"]
@@ -49,5 +52,7 @@ def _get_api_action_handler(self) -> ResourceEvalS3._API_ACTION_HANDLER_TYPE:
4952
return resolver_handler
5053

5154
def eval_resource(self, env: Environment) -> None:
55+
self.resource.eval(env=env)
56+
resource_runtime_part: ResourceRuntimePart = env.stack.pop()
5257
resolver_handler = self._get_api_action_handler()
53-
resolver_handler(env)
58+
resolver_handler(env, resource_runtime_part)

localstack/services/stepfunctions/asl/component/state/state_execution/state_task/lambda_eval_utils.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,9 @@
22
from json import JSONDecodeError
33
from typing import Any, Final, Optional
44

5-
from botocore.config import Config
6-
75
from localstack.aws.api.lambda_ import InvocationResponse
8-
from localstack.aws.connect import connect_externally_to
96
from localstack.services.stepfunctions.asl.eval.environment import Environment
7+
from localstack.services.stepfunctions.asl.utils.boto_client import boto_client_for
108
from localstack.services.stepfunctions.asl.utils.encoding import to_json_str
119
from localstack.utils.collections import select_from_typed_dict
1210
from localstack.utils.run import to_str
@@ -22,8 +20,9 @@ def __init__(self, function_error: Optional[str], payload: str):
2220
self.payload = payload
2321

2422

25-
def exec_lambda_function(env: Environment, parameters: dict) -> None:
26-
lambda_client = connect_externally_to(config=Config(parameter_validation=False)).lambda_
23+
def exec_lambda_function(env: Environment, parameters: dict, region: str, account: str) -> None:
24+
lambda_client = boto_client_for(region=region, account=account, service="lambda")
25+
2726
invocation_resp: InvocationResponse = lambda_client.invoke(**parameters)
2827

2928
func_error: Optional[str] = invocation_resp.get("FunctionError")
@@ -35,7 +34,6 @@ def exec_lambda_function(env: Environment, parameters: dict) -> None:
3534
resp_payload = invocation_resp["Payload"].read()
3635
resp_payload_str = to_str(resp_payload)
3736
resp_payload_json: json = json.loads(resp_payload_str)
38-
# resp_payload_value = resp_payload_json if resp_payload_json is not None else dict()
3937
invocation_resp["Payload"] = resp_payload_json
4038

4139
response = select_from_typed_dict(typed_dict=InvocationResponse, obj=invocation_resp)

localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/resource.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from itertools import takewhile
55
from typing import Final, Optional
66

7-
from localstack.services.stepfunctions.asl.component.component import Component
8-
from localstack.utils.aws import aws_stack
7+
from localstack.services.stepfunctions.asl.component.eval_component import EvalComponent
8+
from localstack.services.stepfunctions.asl.eval.environment import Environment
99

1010

1111
class ResourceCondition(str):
@@ -72,23 +72,30 @@ def from_arn(cls, arn: str) -> ResourceARN:
7272
)
7373

7474

75-
class Resource(Component, abc.ABC):
75+
class ResourceRuntimePart:
76+
account: Final[str]
77+
region: Final[str]
78+
79+
def __init__(self, account: str, region: str):
80+
self.region = region
81+
self.account = account
82+
83+
84+
class Resource(EvalComponent, abc.ABC):
85+
_region: Final[str]
86+
_account: Final[str]
7687
resource_arn: Final[str]
7788
partition: Final[str]
78-
region: Final[str]
79-
account: Final[str]
8089

8190
def __init__(self, resource_arn: ResourceARN):
91+
self._region = resource_arn.region
92+
self._account = resource_arn.account
8293
self.resource_arn = resource_arn.arn
8394
self.partition = resource_arn.partition
84-
self.region = resource_arn.region
85-
self.account = resource_arn.account
8695

8796
@staticmethod
8897
def from_resource_arn(arn: str) -> Resource:
8998
resource_arn = ResourceARN.from_arn(arn)
90-
if not resource_arn.region:
91-
resource_arn.region = aws_stack.get_region()
9299
match resource_arn.service, resource_arn.task_type:
93100
case "lambda", "function":
94101
return LambdaResource(resource_arn=resource_arn)
@@ -97,6 +104,18 @@ def from_resource_arn(arn: str) -> Resource:
97104
case "states", _:
98105
return ServiceResource(resource_arn=resource_arn)
99106

107+
def _eval_runtime_part(self, env: Environment) -> ResourceRuntimePart:
108+
region = self._region if self._region else env.aws_execution_details.region
109+
account = self._account if self._account else env.aws_execution_details.account
110+
return ResourceRuntimePart(
111+
account=account,
112+
region=region,
113+
)
114+
115+
def _eval_body(self, env: Environment) -> None:
116+
runtime_part = self._eval_runtime_part(env=env)
117+
env.stack.append(runtime_part)
118+
100119

101120
class ActivityResource(Resource):
102121
name: Final[str]
@@ -107,11 +126,12 @@ def __init__(self, resource_arn: ResourceARN):
107126

108127

109128
class LambdaResource(Resource):
129+
110130
function_name: Final[str]
111131

112132
def __init__(self, resource_arn: ResourceARN):
113133
super().__init__(resource_arn=resource_arn)
114-
self.function_name: str = resource_arn.name
134+
self.function_name = resource_arn.name
115135

116136

117137
class ServiceResource(Resource):

localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
StatesErrorNameType,
2121
)
2222
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
23+
ResourceRuntimePart,
2324
ServiceResource,
2425
)
2526
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.state_task import (
@@ -53,16 +54,23 @@ def _get_timed_out_failure_event(self) -> FailureEvent:
5354
)
5455

5556
@abc.abstractmethod
56-
def _eval_service_task(self, env: Environment, parameters: dict):
57+
def _eval_service_task(
58+
self,
59+
env: Environment,
60+
resource_runtime_part: ResourceRuntimePart,
61+
normalised_parameters: dict,
62+
):
5763
...
5864

59-
def _before_eval_execution(self, env: Environment, parameters: dict) -> None:
60-
parameters_str = to_json_str(parameters)
65+
def _before_eval_execution(
66+
self, env: Environment, resource_runtime_part: ResourceRuntimePart, raw_parameters: dict
67+
) -> None:
68+
parameters_str = to_json_str(raw_parameters)
6169

6270
scheduled_event_details = TaskScheduledEventDetails(
6371
resource=self._get_sfn_resource(),
6472
resourceType=self._get_sfn_resource_type(),
65-
region=self.resource.region,
73+
region=resource_runtime_part.region,
6674
parameters=parameters_str,
6775
)
6876
if not self.timeout.is_default_value():
@@ -87,7 +95,12 @@ def _before_eval_execution(self, env: Environment, parameters: dict) -> None:
8795
),
8896
)
8997

90-
def _after_eval_execution(self, env: Environment) -> None:
98+
def _after_eval_execution(
99+
self,
100+
env: Environment,
101+
resource_runtime_part: ResourceRuntimePart,
102+
normalised_parameters: dict,
103+
) -> None:
91104
output = env.stack[-1]
92105
env.event_history.add_event(
93106
hist_type_event=HistoryEventType.TaskSucceeded,
@@ -102,10 +115,24 @@ def _after_eval_execution(self, env: Environment) -> None:
102115
)
103116

104117
def _eval_execution(self, env: Environment) -> None:
105-
parameters = self._eval_parameters(env=env)
106-
self._before_eval_execution(env=env, parameters=parameters)
118+
self.resource.eval(env=env)
119+
resource_runtime_part: ResourceRuntimePart = env.stack.pop()
107120

108-
normalised_parameters = self._normalised_parameters_bindings(parameters)
109-
self._eval_service_task(env=env, parameters=normalised_parameters)
121+
raw_parameters = self._eval_parameters(env=env)
110122

111-
self._after_eval_execution(env=env)
123+
self._before_eval_execution(
124+
env=env, resource_runtime_part=resource_runtime_part, raw_parameters=raw_parameters
125+
)
126+
127+
normalised_parameters = self._normalised_parameters_bindings(raw_parameters)
128+
self._eval_service_task(
129+
env=env,
130+
resource_runtime_part=resource_runtime_part,
131+
normalised_parameters=normalised_parameters,
132+
)
133+
134+
self._after_eval_execution(
135+
env=env,
136+
resource_runtime_part=resource_runtime_part,
137+
normalised_parameters=normalised_parameters,
138+
)

localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_api_gateway.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
from localstack.services.stepfunctions.asl.component.common.error_name.failure_event import (
2525
FailureEvent,
2626
)
27+
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
28+
ResourceRuntimePart,
29+
)
2730
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service_callback import (
2831
StateTaskServiceCallback,
2932
)
@@ -98,7 +101,7 @@ def __init__(self, parameters: TaskParameters, response: Response):
98101
class StateTaskServiceApiGateway(StateTaskServiceCallback):
99102

100103
_SUPPORTED_API_PARAM_BINDINGS: Final[dict[str, set[str]]] = {
101-
SupportedApiCalls.invoke: set(TaskParameters.__required_keys__)
104+
SupportedApiCalls.invoke: set(TaskParameters.__required_keys__) # noqa
102105
}
103106

104107
_FORBIDDEN_HTTP_HEADERS_PREFIX: Final[set[str]] = {"X-Forwarded", "X-Amz", "X-Amzn"}
@@ -246,9 +249,14 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent:
246249
),
247250
)
248251

249-
def _eval_service_task(self, env: Environment, parameters: dict):
252+
def _eval_service_task(
253+
self,
254+
env: Environment,
255+
resource_runtime_part: ResourceRuntimePart,
256+
normalised_parameters: dict,
257+
):
250258
task_parameters: TaskParameters = select_from_typed_dict(
251-
typed_dict=TaskParameters, obj=parameters
259+
typed_dict=TaskParameters, obj=normalised_parameters
252260
)
253261

254262
method = task_parameters["Method"]

localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_aws_sdk.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
from botocore.config import Config
21
from botocore.exceptions import ClientError
32

43
from localstack.aws.api.stepfunctions import HistoryEventType, TaskFailedEventDetails
5-
from localstack.aws.connect import connect_externally_to
64
from localstack.aws.protocol.service_router import get_service_catalog
75
from localstack.services.stepfunctions.asl.component.common.error_name.failure_event import (
86
FailureEvent,
@@ -13,12 +11,16 @@
1311
from localstack.services.stepfunctions.asl.component.common.error_name.states_error_name_type import (
1412
StatesErrorNameType,
1513
)
14+
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
15+
ResourceRuntimePart,
16+
)
1617
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service_callback import (
1718
StateTaskServiceCallback,
1819
)
1920
from localstack.services.stepfunctions.asl.component.state.state_props import StateProps
2021
from localstack.services.stepfunctions.asl.eval.environment import Environment
2122
from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails
23+
from localstack.services.stepfunctions.asl.utils.boto_client import boto_client_for
2224
from localstack.utils.common import camel_to_snake_case
2325

2426

@@ -104,11 +106,20 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent:
104106
return failure_event
105107
return super()._from_error(env=env, ex=ex)
106108

107-
def _eval_service_task(self, env: Environment, parameters: dict) -> None:
108-
api_client = connect_externally_to.get_client(
109-
service_name=self._normalised_api_name, config=Config(parameter_validation=False)
109+
def _eval_service_task(
110+
self,
111+
env: Environment,
112+
resource_runtime_part: ResourceRuntimePart,
113+
normalised_parameters: dict,
114+
):
115+
api_client = boto_client_for(
116+
region=resource_runtime_part.region,
117+
account=resource_runtime_part.account,
118+
service=self._normalised_api_name,
119+
)
120+
response = (
121+
getattr(api_client, self._normalised_api_action)(**normalised_parameters) or dict()
110122
)
111-
response = getattr(api_client, self._normalised_api_action)(**parameters) or dict()
112123
if response:
113124
response.pop("ResponseMetadata", None)
114125
env.stack.append(response)

0 commit comments

Comments
 (0)
0