(torch/elastic) fix scale down bug caused by calling rdzv_handler.shutdown() on premature agent failures · pytorch/pytorch@e3dc0d9 · GitHub
[go: up one dir, main page]

Skip to content

Commit e3dc0d9

Browse files
Kiuk Chung and facebook-github-bot
authored and committed
(torch/elastic) fix scale down bug caused by calling rdzv_handler.shutdown() on premature agent failures (#67749)
Summary: Pull Request resolved: #67749 Fixes: #67742 Test Plan: Added unittests. Validated manually: ``` # start agent 0 $ torchrun --rdzv_backend c10d --rdzv_id 123 --rdzv_endpoint localhost:29500 --nnodes 1:2 --nproc_per_node 1 --monitor_interval 1 test.py # start agent 1 torchrun --rdzv_backend c10d --rdzv_id 123 --rdzv_endpoint localhost:29500 --nnodes 1:2 --nproc_per_node 1 --monitor_interval 1 test.py # kill agent 0 CTRL+C (SIGINT) or kill -15 (SIGTERM) # restart it torchrun --rdzv_backend c10d --rdzv_id 123 --rdzv_endpoint localhost:29500 --nnodes 1:2 --nproc_per_node 1 --monitor_interval 1 test.py ``` Reviewed By: cbalioglu Differential Revision: D32129005 fbshipit-source-id: 4e695d0b3397951d375ecee321add5faf0cfa3ea
1 parent 5a48868 commit e3dc0d9

File tree

4 files changed

+115
-82
lines changed

4 files changed

+115
-82
lines changed

test/distributed/elastic/agent/server/test/api_test.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -529,33 +529,33 @@ def set_timeout(self, timeout):
529529
self.assertEquals(expected_info.serialize(), store.value)
530530
store_mock.assert_called_once()
531531

532-
def test_get_agent_status_event(self):
532+
def test_get_event(self):
533533
spec = self._get_worker_spec(max_restarts=1)
534534
agent = TestAgent(spec)
535-
actual_event = agent.get_agent_status_event(state=WorkerState.SUCCEEDED)
536-
self.assertEqual("AGENT", actual_event.source)
537-
self.assertEqual("static", actual_event.metadata["rdzv_backend"])
538-
self.assertEqual(WorkerState.SUCCEEDED.value, actual_event.metadata["state"])
539-
self.assertEqual(spec.role, actual_event.metadata["role"])
535+
event = agent.get_event_succeeded()
536+
self.assertEqual("AGENT", event.source)
537+
self.assertEqual("static", event.metadata["rdzv_backend"])
538+
self.assertEqual("SUCCEEDED", event.metadata["state"])
539+
self.assertEqual(spec.role, event.metadata["role"])
540540

541541
def test_get_worker_status_event(self):
542542
spec = self._get_worker_spec(max_restarts=4)
543543
agent = TestAgent(spec)
544544
agent._remaining_restarts = spec.max_restarts - 2
545545
actual_event = agent._construct_event(
546-
state=WorkerState.SUCCEEDED.value,
546+
state="SUCCEEDED",
547547
source="WORKER",
548548
worker=agent._worker_group.workers[0],
549549
)
550550
self.assertEqual("WORKER", actual_event.source)
551551
self.assertEqual("static", actual_event.metadata["rdzv_backend"])
552-
self.assertEqual(WorkerState.SUCCEEDED.value, actual_event.metadata["state"])
552+
self.assertEqual("SUCCEEDED", actual_event.metadata["state"])
553553
self.assertEqual(spec.role, actual_event.metadata["role"])
554554
self.assertEqual(2, actual_event.metadata["agent_restarts"])
555555

556556
@patch("torch.distributed.elastic.agent.server.api.put_metric")
557557
@patch.object(TestAgent, "_invoke_run")
558-
def test_agent_process_signal_exception(self, invoke_run, put_metric_mock):
558+
def test_agent_process_signal_exception(self, invoke_run, _):
559559
spec = self._get_worker_spec(max_restarts=0)
560560
agent = TestAgent(spec)
561561
invoke_run.side_effect = SignalException(

test/distributed/launcher/api_test.py

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,29 @@
99
import multiprocessing as mp
1010
import os
1111
import shutil
12+
import signal
1213
import sys
1314
import tempfile
1415
import time
1516
import unittest
1617
import uuid
1718
from contextlib import closing
18-
from typing import Optional, Any, Dict
19+
from typing import Any, Dict, Optional
1920
from unittest import mock
20-
from unittest.mock import Mock, patch
21+
from unittest.mock import MagicMock, Mock, patch
2122

2223
import torch
2324
import torch.distributed as dist
2425
from torch.distributed.elastic.agent.server.api import RunResult, WorkerState
26+
from torch.distributed.elastic.multiprocessing.api import SignalException
2527
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
2628
from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
2729
from torch.distributed.elastic.utils import get_socket_with_port
2830
from torch.distributed.launcher.api import (
2931
LaunchConfig,
30-
elastic_launch,
3132
_get_entrypoint_name,
33+
elastic_launch,
34+
launch_agent,
3235
)
3336
from torch.testing._internal.common_utils import (
3437
TEST_WITH_DEV_DBG_ASAN,
@@ -60,10 +63,21 @@ def _dist_sum(wait=0):
6063
return t.item()
6164

6265

66+
ELASTIC_AGENT_RUN = "torch.distributed.launcher.api.LocalElasticAgent.run"
67+
EVENTS_RECORD = "torch.distributed.launcher.api.events.record"
68+
GET_RDZV_HANDLER = (
69+
"torch.distributed.elastic.rendezvous.registry.get_rendezvous_handler"
70+
)
71+
72+
6373
class MockException(Exception):
6474
pass
6575

6676

77+
def short_hash():
78+
return str(uuid.uuid4()).split("-")[0]
79+
80+
6781
class ElasticLaunchTest(unittest.TestCase):
6882
@classmethod
6983
def setUpClass(cls):
@@ -128,9 +142,7 @@ def check_works_ran(self, world_size: int):
128142
{str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
129143
)
130144

131-
@sandcastle_skip_if(
132-
TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
133-
)
145+
@sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
134146
def test_launch_script_python(self):
135147
nnodes = 1
136148
nproc_per_node = 4
@@ -145,9 +157,7 @@ def test_launch_script_python(self):
145157
world_size = nnodes * nproc_per_node
146158
self.check_works_ran(world_size)
147159

148-
@sandcastle_skip_if(
149-
TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
150-
)
160+
@sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
151161
def test_launch_script_python_local_rank_transfer(self):
152162
nnodes = 1
153163
nproc_per_node = 4
@@ -162,9 +172,7 @@ def test_launch_script_python_local_rank_transfer(self):
162172
world_size = nnodes * nproc_per_node
163173
self.check_works_ran(world_size)
164174

165-
@sandcastle_skip_if(
166-
TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
167-
)
175+
@sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
168176
def test_launch_script_bash(self):
169177
nnodes = 1
170178
nproc_per_node = 4
@@ -177,9 +185,7 @@ def test_launch_script_bash(self):
177185
world_size = nnodes * nproc_per_node
178186
self.check_works_ran(world_size)
179187

180-
@sandcastle_skip_if(
181-
TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
182-
)
188+
@sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
183189
def test_launch_function(self):
184190
nnodes = 1
185191
nproc_per_node = 4
@@ -193,9 +199,7 @@ def test_launch_function(self):
193199
actual_res = sorted(value for value in res.values())
194200
self.assertEqual(expected_res, actual_res)
195201

196-
@sandcastle_skip_if(
197-
TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
198-
)
202+
@sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
199203
def test_launch_dist_sum_with_static_rdzv(self):
200204
nnodes = 1
201205
nproc_per_node = 4
@@ -224,9 +228,7 @@ def test_launch_dist_sum_with_static_rdzv(self):
224228
actual_res = sorted(value for value in res.values())
225229
self.assertEqual(expected_res, actual_res)
226230

227-
@sandcastle_skip_if(
228-
TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
229-
)
231+
@sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
230232
def test_launch_elastic(self):
231233
nproc_per_node = 4
232234

@@ -338,3 +340,36 @@ def test_get_entrypoint_name(self):
338340
_get_entrypoint_name(sys.executable, ["-u", "test_script.py"]),
339341
)
340342
self.assertEqual("", _get_entrypoint_name(None, []))
343+
344+
@patch(ELASTIC_AGENT_RUN)
345+
@patch(GET_RDZV_HANDLER)
346+
def test_rdzv_handler_shutdown_on_agent_signal(self, mock_get_rdzv, mock_agent_run):
347+
config = self.get_test_launch_config(min_nodes=1, max_nodes=1, nproc_per_node=1)
348+
349+
for sigval in [signal.SIGTERM, signal.SIGINT]:
350+
with patch(EVENTS_RECORD) as record_event_mock:
351+
rdzv_handler_mock = MagicMock()
352+
rdzv_handler_mock.get_run_id.return_value = short_hash()
353+
mock_get_rdzv.return_value = rdzv_handler_mock
354+
355+
mock_agent_run.side_effect = SignalException("test", sigval)
356+
with self.assertRaises(SignalException):
357+
launch_agent(config, simple_rank_scale, [])
358+
rdzv_handler_mock.shutdown.assert_not_called()
359+
record_event_mock.assert_called_once()
360+
361+
@patch(ELASTIC_AGENT_RUN)
362+
@patch(GET_RDZV_HANDLER)
363+
def test_rdzv_handler_shutdown_on_agent_error(self, mock_get_rdzv, mock_agent_run):
364+
config = self.get_test_launch_config(min_nodes=1, max_nodes=1, nproc_per_node=1)
365+
366+
with patch(EVENTS_RECORD) as record_event_mock:
367+
rdzv_handler_mock = MagicMock()
368+
rdzv_handler_mock.get_run_id.return_value = short_hash()
369+
mock_get_rdzv.return_value = rdzv_handler_mock
370+
371+
mock_agent_run.side_effect = RuntimeError("any other exception")
372+
with self.assertRaises(RuntimeError):
373+
launch_agent(config, simple_rank_scale, [])
374+
rdzv_handler_mock.shutdown.assert_called_once()
375+
record_event_mock.assert_called_once()

torch/distributed/elastic/agent/server/api.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
from torch.distributed.elastic.metrics import prof, put_metric
2828
from torch.distributed.elastic.multiprocessing import (
2929
ProcessFailure,
30-
Std,
3130
SignalException,
31+
Std,
3232
)
3333
from torch.distributed.elastic.utils.logging import get_logger
3434

@@ -722,10 +722,17 @@ def run(self, role: str = DEFAULT_ROLE) -> RunResult:
722722
# record the execution time in case there were any exceptions during run.
723723
self._total_execution_time = int(time.monotonic() - start_time)
724724

725-
def get_agent_status_event(self, state: WorkerState) -> Event:
726-
raw_error = traceback.format_exc() if state == WorkerState.FAILED else None
725+
def get_event_failed(self) -> Event:
726+
return self._construct_event(
727+
state="FAILED",
728+
source=EventSource.AGENT,
729+
raw_error=traceback.format_exc(),
730+
)
731+
732+
def get_event_succeeded(self) -> Event:
727733
return self._construct_event(
728-
state.value, EventSource.AGENT, raw_error=raw_error
734+
state="SUCCEEDED",
735+
source=EventSource.AGENT,
729736
)
730737

731738
def _record_worker_events(self, result: RunResult) -> None:

torch/distributed/launcher/api.py

Lines changed: 39 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
import sys
99
import uuid
1010
from dataclasses import dataclass, field
11-
from typing import Any, Callable, Dict, List, Optional, Union, cast, Tuple
11+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
1212

1313
import torch.distributed.elastic.rendezvous.registry as rdzv_registry
1414
from torch.distributed.elastic import events, metrics
15-
from torch.distributed.elastic.agent.server.api import WorkerSpec, WorkerState
15+
from torch.distributed.elastic.agent.server.api import WorkerSpec
1616
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
17-
from torch.distributed.elastic.multiprocessing import Std
17+
from torch.distributed.elastic.multiprocessing import SignalException, Std
1818
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
1919
from torch.distributed.elastic.rendezvous import RendezvousParameters
2020
from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
@@ -131,19 +131,6 @@ def __call__(self, *args):
131131
return launch_agent(self._config, self._entrypoint, list(args))
132132

133133

134-
def _construct_event(config: LaunchConfig) -> events.Event:
135-
metadata = {
136-
"rdzv_backend": config.rdzv_backend,
137-
"run_id": config.run_id,
138-
"role": config.role,
139-
}
140-
return events.Event(
141-
name="torch.distributed.elastic.launch_agent",
142-
source=events.EventSource.AGENT,
143-
metadata=cast(Dict[str, events.EventMetadataValue], metadata),
144-
)
145-
146-
147134
def _get_entrypoint_name(
148135
entrypoint: Union[Callable, str, None], args: List[Any]
149136
) -> str:
@@ -185,16 +172,14 @@ def _get_addr_and_port(
185172
return (master_addr, master_port)
186173

187174

188-
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
189-
# torch.distributed.elastic.multiprocessing.errors.record.
190175
def launch_agent(
191176
config: LaunchConfig,
192177
entrypoint: Union[Callable, str, None],
193178
args: List[Any],
194179
) -> Dict[int, Any]:
195180
if not config.run_id:
196181
run_id = str(uuid.uuid4().int)
197-
logger.warning(f"config has no run_id, generate a new one: {run_id}")
182+
logger.warning(f"config has no run_id, generated a random run_id: {run_id}")
198183
config.run_id = run_id
199184

200185
entrypoint_name = _get_entrypoint_name(entrypoint, args)
@@ -224,33 +209,34 @@ def launch_agent(
224209
**config.rdzv_configs,
225210
)
226211

227-
agent = None
228-
rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_parameters)
229212
master_addr, master_port = _get_addr_and_port(rdzv_parameters)
230-
try:
231-
spec = WorkerSpec(
232-
role=config.role,
233-
local_world_size=config.nproc_per_node,
234-
entrypoint=entrypoint,
235-
args=tuple(args),
236-
rdzv_handler=rdzv_handler,
237-
max_restarts=config.max_restarts,
238-
monitor_interval=config.monitor_interval,
239-
redirects=config.redirects,
240-
tee=config.tee,
241-
master_addr=master_addr,
242-
master_port=master_port,
243-
)
244213

245-
cfg = metrics.MetricsConfig(config.metrics_cfg) if config.metrics_cfg else None
246-
metrics.initialize_metrics(cfg)
214+
spec = WorkerSpec(
215+
role=config.role,
216+
local_world_size=config.nproc_per_node,
217+
entrypoint=entrypoint,
218+
args=tuple(args),
219+
rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters),
220+
max_restarts=config.max_restarts,
221+
monitor_interval=config.monitor_interval,
222+
redirects=config.redirects,
223+
tee=config.tee,
224+
master_addr=master_addr,
225+
master_port=master_port,
226+
)
247227

248-
agent = LocalElasticAgent(
249-
spec=spec, start_method=config.start_method, log_dir=config.log_dir
250-
)
228+
agent = LocalElasticAgent(
229+
spec=spec, start_method=config.start_method, log_dir=config.log_dir
230+
)
231+
232+
shutdown_rdzv = True
233+
try:
234+
metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))
251235

252236
result = agent.run()
253-
events.record(agent.get_agent_status_event(WorkerState.SUCCEEDED))
237+
# records that agent.run() has succeeded NOT that workers have succeeded
238+
events.record(agent.get_event_succeeded())
239+
254240
if result.is_failed():
255241
# ChildFailedError is treated specially by @record
256242
# if the error files for the failed children exist
@@ -260,15 +246,20 @@ def launch_agent(
260246
name=entrypoint_name,
261247
failures=result.failures,
262248
)
263-
else:
264-
return result.return_values
249+
250+
return result.return_values
265251
except ChildFailedError:
266252
raise
253+
except SignalException:
254+
# when the agent dies with a signal do NOT shutdown the rdzv_handler
255+
# since this closes the rendezvous on this rdzv_id permanently and
256+
# prevents any additional scaling events
257+
shutdown_rdzv = False
258+
events.record(agent.get_event_failed())
259+
raise
267260
except Exception:
268-
if agent:
269-
events.record(agent.get_agent_status_event(WorkerState.FAILED))
270-
else:
271-
events.record(_construct_event(config))
261+
events.record(agent.get_event_failed())
272262
raise
273263
finally:
274-
rdzv_handler.shutdown()
264+
if shutdown_rdzv:
265+
spec.rdzv_handler.shutdown()

0 commit comments

Comments (0)