8000 [SFN] Add support for Task Timeouts by MEPalma · Pull Request #8376 · localstack/localstack · GitHub
[go: up one dir, main page]

Skip to content

[SFN] Add support for Task Timeouts #8376

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jun 25, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fixes
  • Loading branch information
MEPalma committed May 10, 2023
commit 2f392e707e1df31da3469116df302d0556ab3fed
2 changes: 1 addition & 1 deletion localstack/services/stepfunctions/asl/antlr/ASLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ result_decl
;

result_path_decl
: RESULTPATH COLON keyword_or_string // TODO keywords too?
: RESULTPATH COLON (keyword_or_string | NULL) // TODO keywords too?
;

output_path_decl
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Generated from /Users/mep/LocalStack/localstack/localstack/services/stepfunctions/asl/antlr/ASLIntrinsicLexer.g4 by ANTLR 4.11.1
# Generated from /Users/mep/LocalStack/localstack/localstack/services/stepfunctions/asl/antlr/ASLIntrinsicLexer.g4 by ANTLR 4.12.0
from antlr4 import *
from io import StringIO
import sys
Expand Down Expand Up @@ -247,7 +247,7 @@ class ASLIntrinsicLexer(Lexer):

def __init__(self, input=None, output:TextIO = sys.stdout):
super().__init__(input, output)
self.checkVersion("4.11.1")
self.checkVersion("4.12.0")
self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache())
self._actions = None
self._predicates = None
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Generated from /Users/mep/LocalStack/localstack/localstack/services/stepfunctions/asl/antlr/ASLIntrinsicParser.g4 by ANTLR 4.11.1
# Generated from /Users/mep/LocalStack/localstack/localstack/services/stepfunctions/asl/antlr/ASLIntrinsicParser.g4 by ANTLR 4.12.0
# encoding: utf-8
from antlr4 import *
from io import StringIO
Expand Down Expand Up @@ -143,7 +143,7 @@ class ASLIntrinsicParser ( Parser ):

def __init__(self, input:TokenStream, output:TextIO = sys.stdout):
super().__init__(input, output)
self.checkVersion("4.11.1")
self.checkVersion("4.12.0")
self._interp = ParserATNSimulator(self, self.atn, self.decisionsToDFA, self.sharedContextCache)
self._predicates = None

Expand Down Expand Up @@ -352,7 +352,7 @@ def state_fun_name(self):
self.enterOuterAlt(localctx, 1)
self.state = 31
_la = self._input.LA(1)
if not(((_la) & ~0x3f) == 0 and ((1 << _la) & 274876858368) != 0):
if not((((_la) & ~0x3f) == 0 and ((1 << _la) & 274876858368) != 0)):
self._errHandler.recoverInline(self)
else:
self._errHandler.reportMatch(self)
Expand Down Expand Up @@ -1260,7 +1260,7 @@ def json_path_query(self, _p:int=0):
if token in [9, 10, 15]:
self.state = 89
_la = self._input.LA(1)
if not(((_la) & ~0x3f) == 0 and ((1 << _la) & 34304) != 0):
if not((((_la) & ~0x3f) == 0 and ((1 << _la) & 34304) != 0)):
self._errHandler.recoverInline(self)
else:
self._errHandler.reportMatch(self)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Generated from /Users/mep/LocalStack/localstack/localstack/services/stepfunctions/asl/antlr/ASLIntrinsicParser.g4 by ANTLR 4.11.1
# Generated from /Users/mep/LocalStack/localstack/localstack/services/stepfunctions/asl/antlr/ASLIntrinsicParser.g4 by ANTLR 4.12.0
from antlr4 import *
if __name__ is not None and "." in __name__:
from .ASLIntrinsicParser import ASLIntrinsicParser
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Generated from /Users/mep/LocalStack/localstack/localstack/services/stepfunctions/asl/antlr/ASLIntrinsicParser.g4 by ANTLR 4.11.1
# Generated from /Users/mep/LocalStack/localstack/localstack/services/stepfunctions/asl/antlr/ASLIntrinsicParser.g4 by ANTLR 4.12.0
from antlr4 import *
if __name__ is not None and "." in __name__:
from .ASLIntrinsicParser import ASLIntrinsicParser
Expand Down

Large diffs are not rendered by default.

1,101 changes: 558 additions & 543 deletions

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from typing import Final
from typing import Final, Optional

from jsonpath_ng import parse

Expand All @@ -11,11 +11,15 @@ class ResultPath(EvalComponent):
DEFAULT_PATH: Final[str] = "$"

def __init__(self, result_path_src: str):
self.result_path_src: Final[str] = result_path_src
self.result_path_src: Final[Optional[str]] = result_path_src

def _eval_body(self, env: Environment) -> None:
result = env.stack.pop()

if self.result_path_src is None:
return

result_expr = parse(self.result_path_src)
result = copy.deepcopy(env.stack.pop())
if env.inp is None:
env.inp = dict()
env.inp = result_expr.update_or_create(env.inp, result)
env.inp = result_expr.update_or_create(env.inp, copy.deepcopy(result))
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ def from_raw(cls, string_dollar: str, string_path_context_obj: str):
def _eval_val(self, env: Environment) -> Any:
if self.path_context_obj.endswith("Task.Token"):
task_token = env.context_object_manager.update_task_token()
return env.callback_pool_manager.add(task_token)
value = JSONPathUtils.extract_json(
self.path_context_obj, env.context_object_manager.context_object
)
env.callback_pool_manager.add(task_token)
value = task_token
else:
value = JSONPathUtils.extract_json(
self.path_context_obj, env.context_object_manager.context_object
)
return value
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
class StateTaskService(StateTask, abc.ABC):
resource: ServiceResource

def _get_resource_type(self) -> str:
def _get_sfn_resource(self) -> str:
return self.resource.api_action

def _get_sfn_resource_type(self) -> str:
return self.resource.service_name

def _eval_parameters(self, env: Environment) -> dict:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
from botocore.exceptions import ClientError

from localstack.aws.api.stepfunctions import (
HistoryEventExecutionDataDetails,
HistoryEventType,
TaskFailedEventDetails,
TaskScheduledEventDetails,
TaskStartedEventDetails,
TaskSucceededEventDetails,
)
from localstack.aws.api.stepfunctions import HistoryEventType, TaskFailedEventDetails
from localstack.aws.protocol.service_router import get_service_catalog
from localstack.services.stepfunctions.asl.component.common.error_name.failure_event import (
FailureEvent,
Expand All @@ -23,63 +16,38 @@
)
from localstack.services.stepfunctions.asl.eval.environment import Environment
from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails
from localstack.services.stepfunctions.asl.utils.encoding import to_json_str
from localstack.utils.aws import aws_stack
from localstack.utils.common import camel_to_snake_case


class StateTaskServiceAwsSdk(StateTaskServiceCallback):
def _get_resource_type(self) -> str:
_API_NAMES: dict[str, str] = {"sfn": "stepfunctions"}
_SFN_TO_BOTO_PARAM_NORMALISERS = {
"stepfunctions": {"send_task_success": {"Output": "output", "TaskToken": "taskToken"}}
}

def _get_sfn_resource_type(self) -> str:
return f"{self.resource.service_name}:{self.resource.api_name}"

def _eval_service_task(self, env: Environment) -> None:
api_name = self.resource.api_name
api_action = camel_to_snake_case(self.resource.api_action)
def _normalise_api_name(self, api_name: str) -> str:
return self._API_NAMES.get(api_name, api_name)

parameters = self._eval_parameters(env=env)

# Simulate scheduled-start workflow.
parameters_str = to_json_str(parameters)
env.event_history.add_event(
hist_type_event=HistoryEventType.TaskScheduled,
event_detail=EventDetails(
taskScheduledEventDetails=TaskScheduledEventDetails(
resourceType=self._get_resource_type(),
resource=self.resource.api_action,
region=self.resource.region,
parameters=parameters_str,
)
),
)
env.event_history.add_event(
hist_type_event=HistoryEventType.TaskStarted,
event_detail=EventDetails(
taskStartedEventDetails=TaskStartedEventDetails(
resourceType=self._get_resource_type(),
resource=self.resource.api_action,
)
),
)

api_client = aws_stack.create_external_boto_client(service_name=api_name)
def _boto_normalise_parameters(self, api_name: str, api_action: str, parameters: dict) -> None:
api_normalisers = self._SFN_TO_BOTO_PARAM_NORMALISERS.get(api_name, None)
if not api_normalisers:
return

response = getattr(api_client, api_action)(**parameters) or dict()
if response:
response.pop("ResponseMetadata", None)
action_normalisers = api_normalisers.get(api_action, None)
if not action_normalisers:
return None

env.stack.append(response)

env.event_history.add_event(
hist_type_event=HistoryEventType.TaskSucceeded,
event_detail=EventDetails(
taskSucceededEventDetails=TaskSucceededEventDetails(
resourceType=self._get_resource_type(),
resource=self.resource.api_action,
output=to_json_str(response),
outputDetails=HistoryEventExecutionDataDetails(truncated=False),
)
),
)
parameter_keys = list(parameters.keys())
for parameter_key in parameter_keys:
norm_parameter_key = action_normalisers.get(parameter_key, None)
if norm_parameter_key:
tmp = parameters[parameter_key]
del parameters[parameter_key]
parameters[norm_parameter_key] = tmp

@staticmethod
def _normalise_service_name(service_name: str) -> str:
Expand All @@ -96,8 +64,8 @@ def _get_task_failure_event(self, error: str, cause: str) -> FailureEvent:
event_type=HistoryEventType.TaskFailed,
event_details=EventDetails(
taskFailedEventDetails=TaskFailedEventDetails(
resourceType=self._get_resource_type(),
resource=self.resource.api_action,
resource=self._get_sfn_resource(),
resourceType=self._get_sfn_resource_type(),
error=error,
cause=cause,
)
Expand Down Expand Up @@ -127,3 +95,20 @@ def _from_error(self, env: Environment, ex: Exception) -> FailureEvent:
error=error, cause=str(ex) # TODO: update cause decoration.
)
return failure_event

def _eval_service_task(self, env: Environment, parameters: dict) -> None:
api_name = self.resource.api_name
api_name = self._normalise_api_name(api_name)
api_action = camel_to_snake_case(self.resource.api_action)

self._boto_normalise_parameters(
api_name=api_name, api_action=api_action, parameters=parameters
)

api_client = aws_stack.create_external_boto_client(service_name=api_name)

response = getattr(api_client, api_action)(**parameters) or dict()
if response:
response.pop("ResponseMetadata", None)

env.stack.append(response)
Original file line number Diff line number Diff line change
@@ -1,19 +1,105 @@
import json
from abc import abstractmethod

from localstack.aws.api.stepfunctions import (
HistoryEventExecutionDataDetails,
HistoryEventType,
TaskScheduledEventDetails,
TaskStartedEventDetails,
TaskSubmittedEventDetails,
TaskSucceededEventDetails,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
ResourceCondition,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service import (
StateTaskService,
)
from localstack.services.stepfunctions.asl.eval.callback.callback import CallbackOutcomeSuccess
from localstack.services.stepfunctions.asl.eval.environment import Environment
from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails
from localstack.services.stepfunctions.asl.utils.encoding import to_json_str


class StateTaskServiceCallback(StateTaskService):
def _get_sfn_resource(self) -> str:
resource = super()._get_sfn_resource()
if self.resource.condition is not None:
resource += f".{self.resource.condition}"
return resource

@abstractmethod
def _eval_service_task(self, env: Environment):
def _eval_service_task(self, env: Environment, parameters: dict):
...

def _wait_for_task_token(self, env: Environment) -> None: # noqa
callback_id = env.context_object_manager.context_object["Task"]["Token"]
callback_endpoint = env.callback_pool_manager.get(callback_id)
outcome = callback_endpoint.wait() # TODO: implement timeout.

if isinstance(outcome, CallbackOutcomeSuccess):
outcome_output = json.loads(outcome.output)
env.stack.append(outcome_output)
else:
raise NotImplementedError(f"Unsupported Callbackoutcome type '{type(outcome)}'.")

def _is_condition(self):
return self.resource.condition is not None

def _eval_execution(self, env: Environment) -> None:
self._eval_service_task(env=env)
if self.resource.condition is not None:
callback_id = env.context_object_manager.context_object["Task"]["Token"]
callback_endpoint = env.callback_pool_manager.get(callback_id)
callback_endpoint.wait() # TODO: implement timeout.
parameters = self._eval_parameters(env=env)
parameters_str = to_json_str(parameters)

env.event_history.add_event(
hist_type_event=HistoryEventType.TaskScheduled,
event_detail=EventDetails(
taskScheduledEventDetails=TaskScheduledEventDetails(
resource=self._get_sfn_resource(),
resourceType=self._get_sfn_resource_type(),
region=self.resource.region,
parameters=parameters_str,
)
),
)
env.event_history.add_event(
hist_type_event=HistoryEventType.TaskStarted,
event_detail=EventDetails(
taskStartedEventDetails=TaskStartedEventDetails(
resource=self._get_sfn_resource(), resourceType=self._get_sfn_resource_type()
)
),
)

self._eval_service_task(env=env, parameters=parameters)

if self._is_condition():
output = env.stack.pop()
env.event_history.add_event(
hist_type_event=HistoryEventType.TaskSubmitted,
event_detail=EventDetails(
taskSubmittedEventDetails=TaskSubmittedEventDetails(
resource=self._get_sfn_resource(),
resourceType=self._get_sfn_resource_type(),
output=to_json_str(output),
outputDetails=HistoryEventExecutionDataDetails(truncated=False),
)
),
)
match self.resource.condition:
case ResourceCondition.WaitForTaskToken:
self._wait_for_task_token(env=env)
case unsupported:
raise NotImplementedError(f"Unsupported callback type '{unsupported}'.")

output = env.stack[-1]
env.event_history.add_event(
hist_type_event=HistoryEventType.TaskSucceeded,
event_detail=EventDetails(
taskSucceededEventDetails=TaskSucceededEventDetails(
resource=self._get_sfn_resource(),
resourceType=self._get_sfn_resource_type(),
output=to_json_str(output),
outputDetails=HistoryEventExecutionDataDetails(truncated=False),
)
),
)
Loading
0