8000 [SFN] Add support for Task Timeouts by MEPalma · Pull Request #8376 · localstack/localstack · GitHub
[go: up one dir, main page]

Skip to content

[SFN] Add support for Task Timeouts #8376

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jun 25, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
base
  • Loading branch information
MEPalma committed May 8, 2023
commit 3b0c2e9380bc41f6c1dfd0037cdb87dd36542102
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,10 @@ def from_raw(cls, string_dollar: str, string_path_context_obj: str):
return cls(field=field, path_context_obj=path_context_obj)

def _eval_val(self, env: Environment) -> Any:
value = JSONPathUtils.extract_json(self.path_context_obj, env.context_object)
if self.path_context_obj.endswith("Task.Token"):
task_token = env.context_object_manager.update_task_token()
return env.callback_pool_manager.add(task_token)
value = JSONPathUtils.extract_json(
self.path_context_obj, env.context_object_manager.context_object
)
return value
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _eval_body(self, env: Environment) -> None:
),
)

env.context_object["State"] = State(
env.context_object_manager.context_object["State"] = State(
EnteredTime=datetime.datetime.now().isoformat(), Name=self.name, RetryCount=0
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,11 @@ def from_state_props(self, state_props: StateProps) -> None:
raise ValueError(f"Missing ItemProcessor definition in props '{state_props}'.")

def _eval_body(self, env: Environment) -> None:
env.context_object["Map"] = Map(Item=Item(Index=-1, Value="Unsupported"))
env.context_object_manager.context_object["Map"] = Map(
Item=Item(Index=-1, Value="Unsupported")
)
super(StateMap, self)._eval_body(env=env)
env.context_object["Map"] = None
env.context_object_manager.context_object["Map"] = None

def _eval_execution(self, env: Environment) -> None:
# Reduce the input to the list of items.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from __future__ import annotations

import abc
from typing import Final, TypedDict
from typing import Final, Optional, TypedDict

from localstack.services.stepfunctions.asl.component.component import Component
from localstack.utils.aws import aws_stack


class ResourceCondition(str):
WaitForTaskToken = "waitForTaskToken"


class ResourceARN(TypedDict):
partition: str
service: str
Expand Down Expand Up @@ -90,6 +94,7 @@ class ServiceResource(Resource):
service_name: Final[str]
api_name: Final[str]
api_action: Final[str]
condition: Final[Optional[str]]

def __init__(
self,
Expand All @@ -105,4 +110,16 @@ def __init__(
)
self.service_name = service_name
self.api_name = api_name
self.api_action = resource_arn.split(":")[-1]

arn_parts = resource_arn.split(":")
tail_part = arn_parts[-1]
tail_parts = tail_part.split(".")
self.api_action = tail_parts[0]

self.condition = None
if len(tail_parts) > 1:
match tail_parts[-1]:
case "waitForTaskToken":
self.condition = ResourceCondition.WaitForTaskToken
case unsupported:
raise RuntimeError(f"Unsupported condition '{unsupported}'.")
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from localstack.services.stepfunctions.asl.component.common.error_name.states_error_name_type import (
StatesErrorNameType,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service import (
StateTaskService,
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service_callback import (
StateTaskServiceCallback,
)
from localstack.services.stepfunctions.asl.eval.environment import Environment
from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails
Expand All @@ -28,13 +28,11 @@
from localstack.utils.common import camel_to_snake_case


class StateTaskServiceAwsSdk(StateTaskService):
class StateTaskServiceAwsSdk(StateTaskServiceCallback):
def _get_resource_type(self) -> str:
return f"{self.resource.service_name}:{self.resource.api_name}"

def _eval_execution(self, env: Environment) -> None:
super()._eval_execution(env=env)

def _eval_service_task(self, env: Environment) -> None:
api_name = self.resource.api_name
api_action = camel_to_snake_case(self.resource.api_action)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from abc import abstractmethod

from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service import (
StateTaskService,
)
from localstack.services.stepfunctions.asl.eval.environment import Environment


class StateTaskServiceCallback(StateTaskService):
@abstractmethod
def _eval_service_task(self, env: Environment):
...

def _eval_execution(self, env: Environment) -> None:
self._eval_service_task(env=env)
if self.resource.condition is not None:
callback_id = env.context_object_manager.context_object["Task"]["Token"]
callback_endpoint = env.callback_pool_manager.get(callback_id)
callback_endpoint.wait() # TODO: implement timeout.
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from localstack.services.stepfunctions.asl.component.common.error_name.states_error_name_type import (
StatesErrorNameType,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service import (
StateTaskService,
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service_callback import (
StateTaskServiceCallback,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.state_task_lambda import (
LambdaFunctionErrorException,
Expand All @@ -32,7 +32,7 @@
from localstack.services.stepfunctions.asl.utils.encoding import to_json_str


class StateTaskServiceLambda(StateTaskService, StateTaskLambda):
class StateTaskServiceLambda(StateTaskServiceCallback, StateTaskLambda):
@staticmethod
def _error_cause_from_client_error(client_error: ClientError) -> tuple[str, str]:
error_code: str = client_error.response["Error"]["Code"]
Expand Down Expand Up @@ -76,7 +76,7 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent:
),
)

def _eval_execution(self, env: Environment) -> None:
def _eval_service_task(self, env: Environment) -> None:
parameters = self._eval_parameters(env=env)
parameters_str = to_json_str(parameters)
env.event_history.add_event(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from localstack.services.stepfunctions.asl.component.common.error_name.failure_event import (
FailureEvent,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service import (
StateTaskService,
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service_callback import (
StateTaskServiceCallback,
)
from localstack.services.stepfunctions.asl.eval.environment import Environment
from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails
Expand All @@ -26,7 +26,7 @@
from localstack.utils.strings import camel_to_snake_case


class StateTaskServiceSqs(StateTaskService):
class StateTaskServiceSqs(StateTaskServiceCallback):
_ERROR_NAME_CLIENT: Final[str] = "SQS.SdkClientException"
_ERROR_NAME_AWS: Final[str] = "SQS.AmazonSQSException"

Expand Down Expand Up @@ -89,7 +89,7 @@ def _eval_parameters(self, env: Environment) -> dict:

return parameters

def _eval_execution(self, env: Environment) -> None:
def _eval_service_task(self, env: Environment) -> None:
parameters = self._eval_parameters(env=env)

parameters_str = to_json_str(parameters)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
Resource,
)
from localstack.services.stepfunctions.asl.component.state.state_props import StateProps
from localstack.services.stepfunctions.asl.eval.contextobject.contex_object import Task
from localstack.services.stepfunctions.asl.eval.environment import Environment


Expand Down Expand Up @@ -61,6 +60,5 @@ def from_state_props(self, state_props: StateProps) -> None:
self.resource = state_props.get(Resource)

def _eval_body(self, env: Environment) -> None:
env.context_object["Task"] = Task(Token="Unsupported")
super(StateTask, self)._eval_body(env=env)
env.context_object["Task"] = None
env.context_object_manager.context_object["Task"] = None
Empty file.
107 changes: 107 additions & 0 deletions localstack/services/stepfunctions/asl/eval/callback/callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import abc
from collections import OrderedDict
from threading import Event
from typing import Final, Optional

from localstack.utils.strings import long_uid

CallbackId = str


class CallbackOutcome(abc.ABC):
callback_id: Final[CallbackId]

def __init__(self, callback_id: str):
self.callback_id = callback_id


class CallbackOutcomeSuccess(CallbackOutcome):
output: Final[str]

def __init__(self, callback_id: CallbackId, output: str):
super().__init__(callback_id=callback_id)
self.output = output


class CallbackOutcomeFailure(CallbackOutcome):
error: Final[str]
cause: Final[str]

def __init__(self, callback_id: CallbackId, error: str, cause: str):
super().__init__(callback_id=callback_id)
self.error = error
self.cause = cause


class CallbackConsumerError(abc.ABC):
...


class CallbackConsumerTimeout(CallbackConsumerError):
pass


class CallbackConsumerLeft(CallbackConsumerError):
pass


class CallbackEndpoint:
callback_id: Final[CallbackId]
_notify_event: Final[Event]
_outcome: Optional[CallbackOutcome]
consumer_error: Optional[CallbackConsumerError]

def __init__(self, callback_id: CallbackId):
self.callback_id = callback_id
self._notify_event = Event()
self._outcome = None

def notify(self, outcome: CallbackOutcome):
self._outcome = outcome
self._notify_event.set()

def wait(self, timeout: Optional[float] = None) -> Optional[CallbackOutcome]:
self._notify_event.wait(timeout=timeout)
return self._outcome

def report(self, consumer_error: CallbackConsumerError) -> None:
self.consumer_error = consumer_error


class CallbackNotifyConsumerError(RuntimeError):
callback_consumer_error: CallbackConsumerError

def __init__(self, callback_consumer_error: CallbackConsumerError):
self.callback_consumer_error = callback_consumer_error


class CallbackPoolManager:
_pool: dict[CallbackId, CallbackEndpoint]

def __init__(self):
self._pool = OrderedDict()

def get(self, callback_id: CallbackId) -> Optional[CallbackEndpoint]:
return self._pool.get(callback_id)

def add(self, callback_id: CallbackId) -> CallbackEndpoint:
if callback_id in self._pool:
raise ValueError("Duplicate callback token id value.")
callback_endpoint = CallbackEndpoint(callback_id=callback_id)
self._pool[callback_id] = callback_endpoint
return callback_endpoint

def generate(self) -> CallbackEndpoint:
return self.add(long_uid())

def notify(self, callback_id: CallbackId, outcome: CallbackOutcome) -> bool:
callback_endpoint = self._pool.pop(callback_id, None)
if callback_endpoint is None:
return False

consumer_error: Optional[CallbackConsumerError] = callback_endpoint.consumer_error
if consumer_error is not None:
raise CallbackNotifyConsumerError(callback_consumer_error=consumer_error)

callback_endpoint.notify(outcome=outcome)
return True
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Optional, TypedDict
from typing import Final, Optional, TypedDict

from localstack.utils.strings import long_uid


class Execution(TypedDict):
Expand Down Expand Up @@ -43,6 +45,18 @@ class ContextObject(TypedDict):
Map: Optional[Map] # Only available when processing a Map state.


class ContextObjectManager:
context_object: Final[ContextObject]

def __init__(self, context_object: ContextObject):
self.context_object = context_object

def update_task_token(self) -> str:
new_token = long_uid()
self.context_object["Task"] = Task(Token=new_token)
return new_token


class ContextObjectInitData(TypedDict):
Execution: Execution
StateMachine: StateMachine
23 changes: 14 additions & 9 deletions localstack/services/stepfunctions/asl/eval/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from typing import Any, Optional

from localstack.aws.api.stepfunctions import ExecutionFailedEventDetails, Timestamp
from localstack.services.stepfunctions.asl.eval.callback.callback import CallbackPoolManager
from localstack.services.stepfunctions.asl.eval.contextobject.contex_object import (
ContextObject,
ContextObjectInitData,
ContextObjectManager,
)
from localstack.services.stepfunctions.asl.eval.event.event_history import EventHistory
from localstack.services.stepfunctions.asl.eval.programstate.program_ended import ProgramEnded
Expand All @@ -29,29 +31,32 @@ def __init__(self, context_object_init: ContextObjectInitData):
self._frames: list[Environment] = list()

self.event_history: EventHistory = EventHistory()
self.callback_pool_manager: CallbackPoolManager = CallbackPoolManager()

self.heap: dict[str, Any] = dict()
self.stack: list[Any] = list()
self.inp: Optional[Any] = None

self.context_object: ContextObject = ContextObject(
Execution=context_object_init["Execution"],
StateMachine=context_object_init["StateMachine"],
State=None,
Task=None,
Map=None,
self.context 8785 _object_manager: ContextObjectManager = ContextObjectManager(
context_object=ContextObject(
Execution=context_object_init["Execution"],
StateMachine=context_object_init["StateMachine"],
State=None,
Task=None,
Map=None,
)
)

@classmethod
def as_frame_of(cls, env: Environment):
context_object_init = ContextObjectInitData(
Execution=env.context_object["Execution"],
StateMachine=env.context_object["StateMachine"],
Execution=env.context_object_manager.context_object["Execution"],
StateMachine=env.context_object_manager.context_object["StateMachine"],
)
frame = cls(context_object_init=context_object_init)
frame.heap = env.heap
frame.event_history = env.event_history
frame.context_object = env.context_object
frame.context_object = env.context_object_manager.context_object
return frame

@property
Expand Down
Loading
0