Step Functions: Improve Nested Map Run Stability by MEPalma · Pull Request #12343 · localstack/localstack · GitHub
[go: up one dir, main page]

Skip to content

Step Functions: Improve Nested Map Run Stability #12343

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import abc
import json
import threading
from typing import Any, Final, Optional

from localstack.aws.api.stepfunctions import (
Expand Down Expand Up @@ -36,9 +35,6 @@
from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.iteration.itemprocessor.processor_config import (
ProcessorConfig,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.iteration.iteration_worker import (
IterationWorker,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.iteration.job import (
JobClosed,
JobPool,
Expand All @@ -56,6 +52,7 @@
class DistributedIterationComponentEvalInput(InlineIterationComponentEvalInput):
item_reader: Final[Optional[ItemReader]]
label: Final[Optional[str]]
map_run_record: Final[MapRunRecord]

def __init__(
self,
Expand All @@ -68,6 +65,7 @@ def __init__(
tolerated_failure_count: int,
tolerated_failure_percentage: float,
label: Optional[str],
map_run_record: MapRunRecord,
):
super().__init__(
state_name=state_name,
Expand All @@ -80,14 +78,10 @@ def __init__(
self.tolerated_failure_count = tolerated_failure_count
self.tolerated_failure_percentage = tolerated_failure_percentage
self.label = label
self.map_run_record = map_run_record


class DistributedIterationComponent(InlineIterationComponent, abc.ABC):
_eval_input: Optional[DistributedIterationComponentEvalInput]
_mutex: Final[threading.Lock]
_map_run_record: Optional[MapRunRecord]
_workers: list[IterationWorker]

def __init__(
self,
query_language: QueryLanguage,
Expand All @@ -103,89 +97,59 @@ def __init__(
comment=comment,
processor_config=processor_config,
)
self._mutex = threading.Lock()
self._map_run_record = None
self._workers = list()

@abc.abstractmethod
def _create_worker(self, env: Environment) -> IterationWorker: ...

def _launch_worker(self, env: Environment) -> IterationWorker:
worker = super()._launch_worker(env=env)
self._workers.append(worker)
return worker

def _set_active_workers(self, workers_number: int, env: Environment) -> None:
with self._mutex:
current_workers_number = len(self._workers)
workers_diff = workers_number - current_workers_number
if workers_diff > 0:
for _ in range(workers_diff):
self._launch_worker(env=env)
elif workers_diff < 0:
deletion_workers = list(self._workers)[workers_diff:]
for worker in deletion_workers:
worker.sig_stop()
self._workers.remove(worker)

def _map_run(self, env: Environment) -> None:

def _map_run(
self, env: Environment, eval_input: DistributedIterationComponentEvalInput
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice conversion to statelessness, using eval_input as a parameter 👍

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docs: Would it make sense to add a comment about the statelessness here for our future selves?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I'll leave a comment in the top-level class for iterations.

) -> None:
input_items: list[json] = env.stack.pop()

input_item_program: Final[Program] = self._get_iteration_program()
self._job_pool = JobPool(job_program=input_item_program, job_inputs=input_items)
job_pool = JobPool(job_program=input_item_program, job_inputs=input_items)

# TODO: add watch on map_run_record update event and adjust the number of running workers accordingly.
max_concurrency = self._map_run_record.max_concurrency
max_concurrency = eval_input.map_run_record.max_concurrency
workers_number = (
len(input_items)
if max_concurrency == DEFAULT_MAX_CONCURRENCY_VALUE
else max_concurrency
)
self._set_active_workers(workers_number=workers_number, env=env)
for _ in range(workers_number):
self._launch_worker(env=env, eval_input=eval_input, job_pool=job_pool)

self._job_pool.await_jobs()
job_pool.await_jobs()

worker_exception: Optional[Exception] = self._job_pool.get_worker_exception()
worker_exception: Optional[Exception] = job_pool.get_worker_exception()
if worker_exception is not None:
raise worker_exception

closed_jobs: list[JobClosed] = self._job_pool.get_closed_jobs()
closed_jobs: list[JobClosed] = job_pool.get_closed_jobs()
outputs: list[Any] = [closed_job.job_output for closed_job in closed_jobs]

env.stack.append(outputs)

def _eval_body(self, env: Environment) -> None:
self._eval_input = env.stack.pop()

self._map_run_record = MapRunRecord(
state_machine_arn=env.states.context_object.context_object_data["StateMachine"]["Id"],
execution_arn=env.states.context_object.context_object_data["Execution"]["Id"],
max_concurrency=self._eval_input.max_concurrency,
tolerated_failure_count=self._eval_input.tolerated_failure_count,
tolerated_failure_percentage=self._eval_input.tolerated_failure_percentage,
label=self._eval_input.label,
)
env.map_run_record_pool_manager.add(self._map_run_record)
eval_input: DistributedIterationComponentEvalInput = env.stack.pop()
map_run_record = eval_input.map_run_record

env.event_manager.add_event(
context=env.event_history_context,
event_type=HistoryEventType.MapRunStarted,
event_details=EventDetails(
mapRunStartedEventDetails=MapRunStartedEventDetails(
mapRunArn=self._map_run_record.map_run_arn
mapRunArn=map_run_record.map_run_arn
)
),
)

parent_event_manager = env.event_manager
try:
if self._eval_input.item_reader:
self._eval_input.item_reader.eval(env=env)
if eval_input.item_reader:
eval_input.item_reader.eval(env=env)
else:
env.stack.append(self._eval_input.input_items)
env.stack.append(eval_input.input_items)

env.event_manager = EventManager()
self._map_run(env=env)
self._map_run(env=env, eval_input=eval_input)

except FailureEventException as failure_event_ex:
map_run_fail_event_detail = MapRunFailedEventDetails()
Expand All @@ -204,7 +168,7 @@ def _eval_body(self, env: Environment) -> None:
event_type=HistoryEventType.MapRunFailed,
event_details=EventDetails(mapRunFailedEventDetails=map_run_fail_event_detail),
)
self._map_run_record.set_stop(status=MapRunStatus.FAILED)
map_run_record.set_stop(status=MapRunStatus.FAILED)
raise failure_event_ex

except Exception as ex:
Expand All @@ -214,17 +178,13 @@ def _eval_body(self, env: Environment) -> None:
event_type=HistoryEventType.MapRunFailed,
event_details=EventDetails(mapRunFailedEventDetails=MapRunFailedEventDetails()),
)
self._map_run_record.set_stop(status=MapRunStatus.FAILED)
map_run_record.set_stop(status=MapRunStatus.FAILED)
raise ex
finally:
env.event_manager = parent_event_manager
self._eval_input = None
self._workers.clear()

# TODO: review workflow of program stops and maprunstops
# program_state = env.program_state()
# if isinstance(program_state, ProgramSucceeded)
# TODO: review workflow of program stops and map run stops
env.event_manager.add_event(
context=env.event_history_context, event_type=HistoryEventType.MapRunSucceeded
)
self._map_run_record.set_stop(status=MapRunStatus.SUCCEEDED)
map_run_record.set_stop(status=MapRunStatus.SUCCEEDED)
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ def __init__(

class InlineIterationComponent(IterationComponent, abc.ABC):
_processor_config: Final[ProcessorConfig]
_eval_input: Optional[InlineIterationComponentEvalInput]
_job_pool: Optional[JobPool]

def __init__(
self,
Expand All @@ -73,45 +71,45 @@ def __init__(
query_language=query_language, start_at=start_at, states=states, comment=comment
)
self._processor_config = processor_config
self._eval_input = None
self._job_pool = None

@abc.abstractmethod
def _create_worker(self, env: Environment) -> IterationWorker: ...

def _launch_worker(self, env: Environment) -> IterationWorker:
worker = self._create_worker(env=env)
def _create_worker(
self, env: Environment, eval_input: InlineIterationComponentEvalInput, job_pool: JobPool
) -> IterationWorker: ...

def _launch_worker(
self, env: Environment, eval_input: InlineIterationComponentEvalInput, job_pool: JobPool
) -> IterationWorker:
worker = self._create_worker(env=env, eval_input=eval_input, job_pool=job_pool)
worker_thread = threading.Thread(target=worker.eval, daemon=True)
TMP_THREADS.append(worker_thread)
worker_thread.start()
return worker

def _eval_body(self, env: Environment) -> None:
self._eval_input = env.stack.pop()
eval_input = env.stack.pop()

max_concurrency: int = self._eval_input.max_concurrency
input_items: list[json] = self._eval_input.input_items
max_concurrency: int = eval_input.max_concurrency
input_items: list[json] = eval_input.input_items

input_item_program: Final[Program] = self._get_iteration_program()
self._job_pool = JobPool(
job_program=input_item_program, job_inputs=self._eval_input.input_items
)
job_pool = JobPool(job_program=input_item_program, job_inputs=eval_input.input_items)

number_of_workers = (
len(input_items)
if max_concurrency == DEFAULT_MAX_CONCURRENCY_VALUE
else max_concurrency
)
for _ in range(number_of_workers):
self._launch_worker(env=env)
self._launch_worker(env=env, eval_input=eval_input, job_pool=job_pool)

self._job_pool.await_jobs()
job_pool.await_jobs()

worker_exception: Optional[Exception] = self._job_pool.get_worker_exception()
worker_exception: Optional[Exception] = job_pool.get_worker_exception()
if worker_exception is not None:
raise worker_exception

closed_jobs: list[JobClosed] = self._job_pool.get_closed_jobs()
closed_jobs: list[JobClosed] = job_pool.get_closed_jobs()
outputs: list[Any] = [closed_job.job_output for closed_job in closed_jobs]

env.stack.append(outputs)
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

from typing import Optional

from localstack.services.stepfunctions.asl.component.common.comment import Comment
from localstack.services.stepfunctions.asl.component.common.flow.start_at import StartAt
from localstack.services.stepfunctions.asl.component.common.query_language import QueryLanguage
Expand All @@ -16,6 +14,9 @@
from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.iteration.itemprocessor.processor_config import (
ProcessorConfig,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.iteration.job import (
JobPool,
)
from localstack.services.stepfunctions.asl.eval.environment import Environment
from localstack.services.stepfunctions.asl.parse.typed_props import TypedProps

Expand All @@ -25,8 +26,6 @@ class DistributedItemProcessorEvalInput(DistributedIterationComponentEvalInput):


class DistributedItemProcessor(DistributedIterationComponent):
_eval_input: Optional[DistributedItemProcessorEvalInput]

@classmethod
def from_props(cls, props: TypedProps) -> DistributedItemProcessor:
item_processor = cls(
Expand All @@ -44,13 +43,15 @@ def from_props(cls, props: TypedProps) -> DistributedItemProcessor:
)
return item_processor

def _create_worker(self, env: Environment) -> DistributedItemProcessorWorker:
def _create_worker(
self, env: Environment, eval_input: DistributedItemProcessorEvalInput, job_pool: JobPool
) -> DistributedItemProcessorWorker:
return DistributedItemProcessorWorker(
work_name=self._eval_input.state_name,
job_pool=self._job_pool,
work_name=eval_input.state_name,
job_pool=job_pool,
env=env,
item_reader=self._eval_input.item_reader,
parameters=self._eval_input.parameters,
item_selector=self._eval_input.item_selector,
map_run_record=self._map_run_record,
item_reader=eval_input.item_reader,
parameters=eval_input.parameters,
item_selector=eval_input.item_selector,
map_run_record=eval_input.map_run_record,
)
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import logging
from typing import Optional

from localstack.services.stepfunctions.asl.component.common.comment import Comment
from localstack.services.stepfunctions.asl.component.common.flow.start_at import StartAt
Expand All @@ -17,6 +16,9 @@
from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.iteration.itemprocessor.processor_config import (
ProcessorConfig,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.iteration.job import (
JobPool,
)
from localstack.services.stepfunctions.asl.eval.environment import Environment
from localstack.services.stepfunctions.asl.parse.typed_props import TypedProps

Expand All @@ -28,8 +30,6 @@ class InlineItemProcessorEvalInput(InlineIterationComponentEvalInput):


class InlineItemProcessor(InlineIterationComponent):
_eval_input: Optional[InlineItemProcessorEvalInput]

@classmethod
def from_props(cls, props: TypedProps) -> InlineItemProcessor:
if not props.get(States):
Expand All @@ -45,11 +45,13 @@ def from_props(cls, props: TypedProps) -> InlineItemProcessor:
)
return item_processor

def _create_worker(self, env: Environment) -> InlineItemProcessorWorker:
def _create_worker(
self, env: Environment, eval_input: InlineItemProcessorEvalInput, job_pool: JobPool
) -> InlineItemProcessorWorker:
return InlineItemProcessorWorker(
work_name=self._eval_input.state_name,
job_pool=self._job_pool,
work_name=eval_input.state_name,
job_pool=job_pool,
env=env,
item_selector=self._eval_input.item_selector,
parameters=self._eval_input.parameters,
item_selector=eval_input.item_selector,
parameters=eval_input.parameters,
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@


class IterationComponent(EvalComponent, abc.ABC):
# Ensure no member variables are used to keep track of the state of
# iteration components: the evaluation must be stateless as for all
# EvalComponents to ensure they can be reused or used concurrently.
_query_language: Final[QueryLanguage]
_start_at: Final[StartAt]
_states: Final[States]
Expand Down
Loading
Loading
0