8000 StepFunctions: Multi-accounts compatibility by viren-nadkarni · Pull Request #9119 · localstack/localstack · GitHub
[go: up one dir, main page]

Skip to content

StepFunctions: Multi-accounts compatibility #9119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
79ec46e
[DEBUG] Use non default account and region
viren-nadkarni Aug 29, 2023
5cdc0bf
Add account and region context to executions
viren-nadkarni Sep 12, 2023
cfc9743
Use execution context for internal DDB client
viren-nadkarni Sep 12, 2023
2682b40
Build ARNs with proper account ID and region name
viren-nadkarni Sep 12, 2023
14eb9a2
Implement account ID namespacing for legacy stepfunctions provider
viren-nadkarni Sep 12, 2023
f4c8a85
Take account ID and region into consideration for internal requests
viren-nadkarni Sep 13, 2023
9cf16d0
Run the legacy SF tests in a hardcoded region
viren-nadkarni Sep 13, 2023
b2917c0
Use an alternative way to pass the context
viren-nadkarni Sep 13, 2023
e43191c
Use request context account ID and region during program construction
viren-nadkarni Sep 14, 2023
5ef3a98
Merge branch 'master' into stepfunctions-multi-accounts
viren-nadkarni Sep 14, 2023
aad9b80
Maintain backward compatibility
viren-nadkarni Sep 14, 2023
74d02eb
Use proper client utility for eventbridge
viren-nadkarni Sep 14, 2023
a4b0954
Merge branch 'master' into stepfunctions-multi-accounts
MEPalma Sep 29, 2023
54ba905
resources as evaluation components
MEPalma Oct 2, 2023
528eb7e
Merge branch 'master' into stepfunctions-multi-accounts
MEPalma Oct 2, 2023
dfc56cc
split resource into static and runtime parts, minors
MEPalma Oct 2, 2023
5d7cfe4
Merge branch 'master' into stepfunctions-multi-accounts
MEPalma Oct 4, 2023
c75017b
conflicts resolution, add resource evaluation to s3 distributed map r…
MEPalma Oct 4, 2023
90fd710
Merge branch 'master' into stepfunctions-multi-accounts
MEPalma Oct 6, 2023
670df14
revert non default account and region
MEPalma Oct 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,42 @@

from typing import Callable, Final

from botocore.config import Config

from localstack.aws.connect import connect_to
from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.item_reader.resource_eval.resource_eval import (
ResourceEval,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
ResourceRuntimePart,
)
from localstack.services.stepfunctions.asl.eval.environment import Environment
from localstack.services.stepfunctions.asl.utils.boto_client import boto_client_for
from localstack.utils.strings import camel_to_snake_case, to_str


class ResourceEvalS3(ResourceEval):
_HANDLER_REFLECTION_PREFIX: Final[str] = "_handle_"
_API_ACTION_HANDLER_TYPE = Callable[[Environment], None]
_API_ACTION_HANDLER_TYPE = Callable[[Environment, ResourceRuntimePart], None]

@staticmethod
def _get_s3_client():
# TODO: adjust following multi-account and invocation region changes.
return connect_to.get_client(
service_name="s3",
config=Config(parameter_validation=False),
def _get_s3_client(resource_runtime_part: ResourceRuntimePart):
return boto_client_for(
region=resource_runtime_part.region,
account=resource_runtime_part.account,
service="s3",
)

@staticmethod
def _handle_get_object(env: Environment) -> None:
s3_client = ResourceEvalS3._get_s3_client()
def _handle_get_object(env: Environment, resource_runtime_part: ResourceRuntimePart) -> None:
s3_client = ResourceEvalS3._get_s3_client(resource_runtime_part=resource_runtime_part)
parameters = env.stack.pop()
response = s3_client.get_object(**parameters)
content = to_str(response["Body"].read())
env.stack.append(content)

@staticmethod
def _handle_list_objects_v2(env: Environment) -> None:
s3_client = ResourceEvalS3._get_s3_client()
def _handle_list_objects_v2(
env: Environment, resource_runtime_part: ResourceRuntimePart
) -> None:
s3_client = ResourceEvalS3._get_s3_client(resource_runtime_part=resource_runtime_part)
parameters = env.stack.pop()
response = s3_client.list_objects_v2(**parameters)
contents = response["Contents"]
Expand All @@ -49,5 +52,7 @@ def _get_api_action_handler(self) -> ResourceEvalS3._API_ACTION_HANDLER_TYPE:
return resolver_handler

def eval_resource(self, env: Environment) -> None:
self.resource.eval(env=env)
resource_runtime_part: ResourceRuntimePart = env.stack.pop()
resolver_handler = self._get_api_action_handler()
resolver_handler(env)
resolver_handler(env, resource_runtime_part)
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
from json import JSONDecodeError
from typing import Any, Final, Optional

from botocore.config import Config

from localstack.aws.api.lambda_ import InvocationResponse
from localstack.aws.connect import connect_externally_to
from localstack.services.stepfunctions.asl.eval.environment import Environment
from localstack.services.stepfunctions.asl.utils.boto_client import boto_client_for
from localstack.services.stepfunctions.asl.utils.encoding import to_json_str
from localstack.utils.collections import select_from_typed_dict
from localstack.utils.run import to_str
Expand All @@ -22,8 +20,9 @@ def __init__(self, function_error: Optional[str], payload: str):
self.payload = payload


def exec_lambda_function(env: Environment, parameters: dict) -> None:
lambda_client = connect_externally_to(config=Config(parameter_validation=False)).lambda_
def exec_lambda_function(env: Environment, parameters: dict, region: str, account: str) -> None:
lambda_client = boto_client_for(region=region, account=account, service="lambda")

invocation_resp: InvocationResponse = lambda_client.invoke(**parameters)

func_error: Optional[str] = invocation_resp.get("FunctionError")
Expand All @@ -35,7 +34,6 @@ def exec_lambda_function(env: Environment, parameters: dict) -> None:
resp_payload = invocation_resp["Payload"].read()
resp_payload_str = to_str(resp_payload)
resp_payload_json: json = json.loads(resp_payload_str)
# resp_payload_value = resp_payload_json if resp_payload_json is not None else dict()
invocation_resp["Payload"] = resp_payload_json

response = select_from_typed_dict(typed_dict=InvocationResponse, obj=invocation_resp)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from itertools import takewhile
from typing import Final, Optional

from localstack.services.stepfunctions.asl.component.component import Component
from localstack.utils.aws import aws_stack
from localstack.services.stepfunctions.asl.component.eval_component import EvalComponent
from localstack.services.stepfunctions.asl.eval.environment import Environment


class ResourceCondition(str):
Expand Down Expand Up @@ -72,23 +72,30 @@ def from_arn(cls, arn: str) -> ResourceARN:
)


class Resource(Component, abc.ABC):
class ResourceRuntimePart:
account: Final[str]
region: Final[str]

def __init__(self, account: str, region: str):
self.region = region
self.account = account


class Resource(EvalComponent, abc.ABC):
_region: Final[str]
_account: Final[str]
resource_arn: Final[str]
partition: Final[str]
region: Final[str]
account: Final[str]

def __init__(self, resource_arn: ResourceARN):
self._region = resource_arn.region
self._account = resource_arn.account
self.resource_arn = resource_arn.arn
self.partition = resource_arn.partition
self.region = resource_arn.region
self.account = resource_arn.account

@staticmethod
def from_resource_arn(arn: str) -> Resource:
resource_arn = ResourceARN.from_arn(arn)
if not resource_arn.region:
resource_arn.region = aws_stack.get_region()
match resource_arn.service, resource_arn.task_type:
case "lambda", "function":
return LambdaResource(resource_arn=resource_arn)
Expand All @@ -97,6 +104,18 @@ def from_resource_arn(arn: str) -> Resource:
case "states", _:
return ServiceResource(resource_arn=resource_arn)

def _eval_runtime_part(self, env: Environment) -> ResourceRuntimePart:
region = self._region if self._region else env.aws_execution_details.region
account = self._account if self._account else env.aws_execution_details.account
return ResourceRuntimePart(
account=account,
region=region,
)

def _eval_body(self, env: Environment) -> None:
runtime_part = self._eval_runtime_part(env=env)
env.stack.append(runtime_part)


class ActivityResource(Resource):
name: Final[str]
Expand All @@ -107,11 +126,12 @@ def __init__(self, resource_arn: ResourceARN):


class LambdaResource(Resource):

function_name: Final[str]

def __init__(self, resource_arn: ResourceARN):
super().__init__(resource_arn=resource_arn)
self.function_name: str = resource_arn.name
self.function_name = resource_arn.name


class ServiceResource(Resource):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
StatesErrorNameType,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
ResourceRuntimePart,
ServiceResource,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.state_task import (
Expand Down Expand Up @@ -53,16 +54,23 @@ def _get_timed_out_failure_event(self) -> FailureEvent:
)

@abc.abstractmethod
def _eval_service_task(self, env: Environment, parameters: dict):
def _eval_service_task(
self,
env: Environment,
resource_runtime_part: ResourceRuntimePart,
normalised_parameters: dict,
):
...

def _before_eval_execution(self, env: Environment, parameters: dict) -> None:
parameters_str = to_json_str(parameters)
def _before_eval_execution(
self, env: Environment, resource_runtime_part: ResourceRuntimePart, raw_parameters: dict
) -> None:
parameters_str = to_json_str(raw_parameters)

scheduled_event_details = TaskScheduledEventDetails(
resource=self._get_sfn_resource(),
resourceType=self._get_sfn_resource_type(),
region=self.resource.region,
region=resource_runtime_part.region,
parameters=parameters_str,
)
if not self.timeout.is_default_value():
Expand All @@ -87,7 +95,12 @@ def _before_eval_execution(self, env: Environment, parameters: dict) -> None:
),
)

def _after_eval_execution(self, env: Environment) -> None:
def _after_eval_execution(
self,
env: Environment,
resource_runtime_part: ResourceRuntimePart,
normalised_parameters: dict,
) -> None:
output = env.stack[-1]
env.event_history.add_event(
hist_type_event=HistoryEventType.TaskSucceeded,
Expand All @@ -102,10 +115,24 @@ def _after_eval_execution(self, env: Environment) -> None:
)

def _eval_execution(self, env: Environment) -> None:
parameters = self._eval_parameters(env=env)
self._before_eval_execution(env=env, parameters=parameters)
self.resource.eval(env=env)
resource_runtime_part: ResourceRuntimePart = env.stack.pop()

normalised_parameters = self._normalised_parameters_bindings(parameters)
self._eval_service_task(env=env, parameters=normalised_parameters)
raw_parameters = self._eval_parameters(env=env)

self._after_eval_execution(env=env)
self._before_eval_execution(
env=env, resource_runtime_part=resource_runtime_part, raw_parameters=raw_parameters
)

normalised_parameters = self._normalised_parameters_bindings(raw_parameters)
self._eval_service_task(
env=env,
resource_runtime_part=resource_runtime_part,
normalised_parameters=normalised_parameters,
)

self._after_eval_execution(
env=env,
resource_runtime_part=resource_runtime_part,
normalised_parameters=normalised_parameters,
)
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
from localstack.services.stepfunctions.asl.component.common.error_name.failure_event import (
FailureEvent,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
ResourceRuntimePart,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service_callback import (
StateTaskServiceCallback,
)
Expand Down Expand Up @@ -98,7 +101,7 @@ def __init__(self, parameters: TaskParameters, response: Response):
class StateTaskServiceApiGateway(StateTaskServiceCallback):

_SUPPORTED_API_PARAM_BINDINGS: Final[dict[str, set[str]]] = {
SupportedApiCalls.invoke: set(TaskParameters.__required_keys__)
SupportedApiCalls.invoke: set(TaskParameters.__required_keys__) # noqa
}

_FORBIDDEN_HTTP_HEADERS_PREFIX: Final[set[str]] = {"X-Forwarded", "X-Amz", "X-Amzn"}
Expand Down Expand Up @@ -246,9 +249,14 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent:
),
)

def _eval_service_task(self, env: Environment, parameters: dict):
def _eval_service_task(
self,
env: Environment,
resource_runtime_part: ResourceRuntimePart,
normalised_parameters: dict,
):
task_parameters: TaskParameters = select_from_typed_dict(
typed_dict=TaskParameters, obj=parameters
typed_dict=TaskParameters, obj=normalised_parameters
)

method = task_parameters["Method"]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from botocore.config import Config
from botocore.exceptions import ClientError

from localstack.aws.api.stepfunctions import HistoryEventType, TaskFailedEventDetails
from localstack.aws.connect import connect_externally_to
from localstack.aws.protocol.service_router import get_service_catalog
from localstack.services.stepfunctions.asl.component.common.error_name.failure_event import (
FailureEvent,
Expand All @@ -13,12 +11,16 @@
from localstack.services.stepfunctions.asl.component.common.error_name.states_error_name_type import (
StatesErrorNameType,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
ResourceRuntimePart,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service_callback import (
StateTaskServiceCallback,
)
from localstack.services.stepfunctions.asl.component.state.state_props import StateProps
from localstack.services.stepfunctions.asl.eval.environment import Environment
from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails
from localstack.services.stepfunctions.asl.utils.boto_client import boto_client_for
from localstack.utils.common import camel_to_snake_case


Expand Down Expand Up @@ -104,11 +106,20 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent:
return failure_event
return super()._from_error(env=env, ex=ex)

def _eval_service_task(self, env: Environment, parameters: dict) -> None:
api_client = connect_externally_to.get_client(
service_name=self._normalised_api_name, config=Config(parameter_validation=False)
def _eval_service_task(
self,
env: Environment,
resource_runtime_part: ResourceRuntimePart,
normalised_parameters: dict,
):
api_client = boto_client_for(
region=resource_runtime_part.region,
account=resource_runtime_part.account,
service=self._normalised_api_name,
)
response = (
getattr(api_client, self._normalised_api_action)(**normalised_parameters) or dict()
)
response = getattr(api_client, self._normalised_api_action)(**parameters) or dict()
if response:
response.pop("ResponseMetadata", None)
env.stack.append(response)
Loading
0