(torch/elastic) fix scale down bug caused by calling rdzv_handler.shutdown() on premature agent failures by kiukchung · Pull Request #67749 · pytorch/pytorch · GitHub

(torch/elastic) fix scale down bug caused by calling rdzv_handler.shutdown() on premature agent failures #67749


Closed
wants to merge 1 commit into from
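
In brief: when the elastic agent dies prematurely because it received a signal (SIGTERM or SIGINT, surfaced as SignalException), launch_agent previously still called rdzv_handler.shutdown() in its finally block. Shutting down the rendezvous closes that run_id permanently, so surviving nodes can never re-rendezvous and further scaling events are blocked. The fix skips the shutdown only for SignalException; normal completion and other errors still shut the rendezvous down. Below is a minimal sketch of the new control flow, simplified from the torch/distributed/launcher/api.py diff further down (launch_agent_flow is a hypothetical stand-in name for the body of launch_agent, not a real function):

```python
from torch.distributed.elastic import events
from torch.distributed.elastic.multiprocessing import SignalException
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError


def launch_agent_flow(agent, spec, entrypoint_name):
    # Sketch only: mirrors the error handling added to launch_agent() in this PR.
    shutdown_rdzv = True
    try:
        result = agent.run()
        # Records that agent.run() succeeded, NOT that the workers succeeded.
        events.record(agent.get_event_succeeded())
        if result.is_failed():
            # ChildFailedError is treated specially by the @record decorator.
            raise ChildFailedError(name=entrypoint_name, failures=result.failures)
        return result.return_values
    except ChildFailedError:
        raise
    except SignalException:
        # The agent was killed by a signal (e.g. SIGTERM during a scale-down).
        # Do NOT shut down the rendezvous: shutdown() closes this rdzv_id
        # permanently and would prevent any additional scaling events.
        shutdown_rdzv = False
        events.record(agent.get_event_failed())
        raise
    except Exception:
        events.record(agent.get_event_failed())
        raise
    finally:
        if shutdown_rdzv:
            spec.rdzv_handler.shutdown()
```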
18 changes: 9 additions & 9 deletions test/distributed/elastic/agent/server/test/api_test.py
@@ -529,33 +529,33 @@ def set_timeout(self, timeout):
self.assertEquals(expected_info.serialize(), store.value)
store_mock.assert_called_once()

def test_get_agent_status_event(self):
def test_get_event(self):
spec = self._get_worker_spec(max_restarts=1)
agent = TestAgent(spec)
actual_event = agent.get_agent_status_event(state=WorkerState.SUCCEEDED)
self.assertEqual("AGENT", actual_event.source)
self.assertEqual("static", actual_event.metadata["rdzv_backend"])
self.assertEqual(WorkerState.SUCCEEDED.value, actual_event.metadata["state"])
self.assertEqual(spec.role, actual_event.metadata["role"])
event = agent.get_event_succeeded()
self.assertEqual("AGENT", event.source)
self.assertEqual("static", event.metadata["rdzv_backend"])
self.assertEqual("SUCCEEDED", event.metadata["state"])
self.assertEqual(spec.role, event.metadata["role"])

def test_get_worker_status_event(self):
spec = self._get_worker_spec(max_restarts=4)
agent = TestAgent(spec)
agent._remaining_restarts = spec.max_restarts - 2
actual_event = agent._construct_event(
state=WorkerState.SUCCEEDED.value,
state="SUCCEEDED",
source="WORKER",
worker=agent._worker_group.workers[0],
)
self.assertEqual("WORKER", actual_event.source)
self.assertEqual("static", actual_event.metadata["rdzv_backend"])
self.assertEqual(WorkerState.SUCCEEDED.value, actual_event.metadata["state"])
self.assertEqual("SUCCEEDED", actual_event.metadata["state"])
self.assertEqual(spec.role, actual_event.metadata["role"])
self.assertEqual(2, actual_event.metadata["agent_restarts"])

@patch("torch.distributed.elastic.agent.server.api.put_metric")
@patch.object(TestAgent, "_invoke_run")
def test_agent_process_signal_exception(self, invoke_run, put_metric_mock):
def test_agent_process_signal_exception(self, invoke_run, _):
spec = self._get_worker_spec(max_restarts=0)
agent = TestAgent(spec)
invoke_run.side_effect = SignalException(
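
For orientation, the assertions above track an API change made further down in torch/distributed/elastic/agent/server/api.py: get_agent_status_event(state) is split into get_event_succeeded() and get_event_failed(), and the event metadata now carries the state as a plain string. A minimal usage sketch of what the updated test expects (check_succeeded_event is a hypothetical helper; agent is assumed to be built from a WorkerSpec using the static rendezvous backend, as TestAgent is in this file):

```python
def check_succeeded_event(agent):
    # Mirrors the assertions in test_get_event above.
    event = agent.get_event_succeeded()
    assert event.source == "AGENT"
    assert event.metadata["state"] == "SUCCEEDED"      # state serialized as a plain string
    assert event.metadata["rdzv_backend"] == "static"  # backend configured by the test's WorkerSpec
```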
77 changes: 56 additions & 21 deletions test/distributed/launcher/api_test.py
@@ -9,26 +9,29 @@
import multiprocessing as mp
import os
import shutil
import signal
import sys
import tempfile
import time
import unittest
import uuid
from contextlib import closing
from typing import Optional, Any, Dict
from typing import Any, Dict, Optional
from unittest import mock
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch

import torch
import torch.distributed as dist
from torch.distributed.elastic.agent.server.api import RunResult, WorkerState
from torch.distributed.elastic.multiprocessing.api import SignalException
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
from torch.distributed.elastic.utils import get_socket_with_port
from torch.distributed.launcher.api import (
LaunchConfig,
elastic_launch,
_get_entrypoint_name,
elastic_launch,
launch_agent,
)
from torch.testing._internal.common_utils import (
TEST_WITH_DEV_DBG_ASAN,
@@ -60,10 +63,21 @@ def _dist_sum(wait=0):
return t.item()


ELASTIC_AGENT_RUN = "torch.distributed.launcher.api.LocalElasticAgent.run"
EVENTS_RECORD = "torch.distributed.launcher.api.events.record"
GET_RDZV_HANDLER = (
"torch.distributed.elastic.rendezvous.registry.get_rendezvous_handler"
)


class MockException(Exception):
pass


def short_hash():
return str(uuid.uuid4()).split("-")[0]


class ElasticLaunchTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
@@ -128,9 +142,7 @@ def check_works_ran(self, world_size: int):
{str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
)

@sandcastle_skip_if(
TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
)
@sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
def test_launch_script_python(self):
nnodes = 1
nproc_per_node = 4
@@ -145,9 +157,7 @@ def test_launch_script_python(self):
world_size = nnodes * nproc_per_node
self.check_works_ran(world_size)

@sandcastle_skip_if(
TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
)
@sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
def test_launch_script_python_local_rank_transfer(self):
nnodes = 1
nproc_per_node = 4
@@ -162,9 +172,7 @@ def test_launch_script_python_local_rank_transfer(self):
world_size = nnodes * nproc_per_node
self.check_works_ran(world_size)

@sandcastle_skip_if(
TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
)
@sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
def test_launch_script_bash(self):
nnodes = 1
nproc_per_node = 4
@@ -177,9 +185,7 @@ def test_launch_script_bash(self):
world_size = nnodes * nproc_per_node
self.check_works_ran(world_size)

@sandcastle_skip_if(
TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
)
@sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
def test_launch_function(self):
nnodes = 1
nproc_per_node = 4
@@ -193,9 +199,7 @@ def test_launch_function(self):
actual_res = sorted(value for value in res.values())
self.assertEqual(expected_res, actual_res)

@sandcastle_skip_if(
TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
)
@sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
def test_launch_dist_sum_with_static_rdzv(self):
nnodes = 1
nproc_per_node = 4
@@ -224,9 +228,7 @@ def test_launch_dist_sum_with_static_rdzv(self):
actual_res = sorted(value for value in res.values())
self.assertEqual(expected_res, actual_res)

@sandcastle_skip_if(
TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
)
@sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
def test_launch_elastic(self):
nproc_per_node = 4

@@ -338,3 +340,36 @@ def test_get_entrypoint_name(self):
_get_entrypoint_name(sys.executable, ["-u", "test_script.py"]),
)
self.assertEqual("", _get_entrypoint_name(None, []))

@patch(ELASTIC_AGENT_RUN)
@patch(GET_RDZV_HANDLER)
def test_rdzv_handler_shutdown_on_agent_signal(self, mock_get_rdzv, mock_agent_run):
config = self.get_test_launch_config(min_nodes=1, max_nodes=1, nproc_per_node=1)

for sigval in [signal.SIGTERM, signal.SIGINT]:
with patch(EVENTS_RECORD) as record_event_mock:
rdzv_handler_mock = MagicMock()
rdzv_handler_mock.get_run_id.return_value = short_hash()
mock_get_rdzv.return_value = rdzv_handler_mock

mock_agent_run.side_effect = SignalException("test", sigval)
with self.assertRaises(SignalException):
launch_agent(config, simple_rank_scale, [])
rdzv_handler_mock.shutdown.assert_not_called()
record_event_mock.assert_called_once()

@patch(ELASTIC_AGENT_RUN)
@patch(GET_RDZV_HANDLER)
def test_rdzv_handler_shutdown_on_agent_error(self, mock_get_rdzv, mock_agent_run):
config = self.get_test_launch_config(min_nodes=1, max_nodes=1, nproc_per_node=1)

with patch(EVENTS_RECORD) as record_event_mock:
rdzv_handler_mock = MagicMock()
rdzv_handler_mock.get_run_id.return_value = short_hash()
mock_get_rdzv.return_value = rdzv_handler_mock

mock_agent_run.side_effect = RuntimeError("any other exception")
with self.assertRaises(RuntimeError):
launch_agent(config, simple_rank_scale, [])
rdzv_handler_mock.shutdown.assert_called_once()
record_event_mock.assert_called_once()
15 changes: 11 additions & 4 deletions torch/distributed/elastic/agent/server/api.py
@@ -27,8 +27,8 @@
from torch.distributed.elastic.metrics import prof, put_metric
from torch.distributed.elastic.multiprocessing import (
ProcessFailure,
Std,
SignalException,
Std,
)
from torch.distributed.elastic.utils.logging import get_logger

@@ -722,10 +722,17 @@ def run(self, role: str = DEFAULT_ROLE) -> RunResult:
# record the execution time in case there were any exceptions during run.
self._total_execution_time = int(time.monotonic() - start_time)

def get_agent_status_event(self, state: WorkerState) -> Event:
raw_error = traceback.format_exc() if state == WorkerState.FAILED else None
def get_event_failed(self) -> Event:
return self._construct_event(
state="FAILED",
source=EventSource.AGENT,
raw_error=traceback.format_exc(),
)

def get_event_succeeded(self) -> Event:
return self._construct_event(
state.value, EventSource.AGENT, raw_error=raw_error
state="SUCCEEDED",
source=EventSource.AGENT,
)

def _record_worker_events(self, result: RunResult) -> None:
87 changes: 39 additions & 48 deletions torch/distributed/launcher/api.py
@@ -8,13 +8,13 @@
import sys
import uuid
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Union, cast, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch.distributed.elastic.rendezvous.registry as rdzv_registry
from torch.distributed.elastic import events, metrics
from torch.distributed.elastic.agent.server.api import WorkerSpec, WorkerState
from torch.distributed.elastic.agent.server.api import WorkerSpec
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
from torch.distributed.elastic.multiprocessing import Std
from torch.distributed.elastic.multiprocessing import SignalException, Std
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
from torch.distributed.elastic.rendezvous import RendezvousParameters
from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
@@ -131,19 +131,6 @@ def __call__(self, *args):
return launch_agent(self._config, self._entrypoint, list(args))


def _construct_event(config: LaunchConfig) -> events.Event:
metadata = {
"rdzv_backend": config.rdzv_backend,
"run_id": config.run_id,
"role": config.role,
}
return events.Event(
name="torch.distributed.elastic.launch_agent",
source=events.EventSource.AGENT,
metadata=cast(Dict[str, events.EventMetadataValue], metadata),
)


def _get_entrypoint_name(
entrypoint: Union[Callable, str, None], args: List[Any]
) -> str:
@@ -185,16 +172,14 @@ def _get_addr_and_port(
return (master_addr, master_port)


# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# torch.distributed.elastic.multiprocessing.errors.record.
def launch_agent(
config: LaunchConfig,
entrypoint: Union[Callable, str, None],
args: List[Any],
) -> Dict[int, Any]:
if not config.run_id:
run_id = str(uuid.uuid4().int)
logger.warning(f"config has no run_id, generate a new one: {run_id}")
logger.warning(f"config has no run_id, generated a random run_id: {run_id}")
config.run_id = run_id

entrypoint_name = _get_entrypoint_name(entrypoint, args)
@@ -224,33 +209,34 @@ launch_agent(
**config.rdzv_configs,
)

agent = None
rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_parameters)
master_addr, master_port = _get_addr_and_port(rdzv_parameters)
try:
spec = WorkerSpec(
role=config.role,
local_world_size=config.nproc_per_node,
entrypoint=entrypoint,
args=tuple(args),
rdzv_handler=rdzv_handler,
max_restarts=config.max_restarts,
monitor_interval=config.monitor_interval,
redirects=config.redirects,
tee=config.tee,
master_addr=master_addr,
master_port=master_port,
)

cfg = metrics.MetricsConfig(config.metrics_cfg) if config.metrics_cfg else None
metrics.initialize_metrics(cfg)
spec = WorkerSpec(
role=config.role,
local_world_size=config.nproc_per_node,
entrypoint=entrypoint,
args=tuple(args),
rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters),
max_restarts=config.max_restarts,
monitor_interval=config.monitor_interval,
redirects=config.redirects,
tee=config.tee,
master_addr=master_addr,
master_port=master_port,
)

agent = LocalElasticAgent(
spec=spec, start_method=config.start_method, log_dir=config.log_dir
)
agent = LocalElasticAgent(
spec=spec, start_method=config.start_method, log_dir=config.log_dir
)

shutdown_rdzv = True
try:
metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))

result = agent.run()
events.record(agent.get_agent_status_event(WorkerState.SUCCEEDED))
# records that agent.run() has succeeded NOT that workers have succeeded
events.record(agent.get_event_succeeded())

if result.is_failed():
# ChildFailedError is treated specially by @record
# if the error files for the failed children exist
@@ -260,15 +246,20 @@
name=entrypoint_name,
failures=result.failures,
)
else:
return result.return_values

return result.return_values
except ChildFailedError:
raise
except SignalException:
# when the agent dies with a signal do NOT shutdown the rdzv_handler
# since this closes the rendezvous on this rdzv_id permanently and
# prevents any additional scaling events
shutdown_rdzv = False
events.record(agent.get_event_failed())
raise
except Exception:
if agent:
events.record(agent.get_agent_status_event(WorkerState.FAILED))
else:
events.record(_construct_event(config))
events.record(agent.get_event_failed())
raise
finally:
rdzv_handler.shutdown()
if shutdown_rdzv:
spec.rdzv_handler.shutdown()