8000 elastic: do not shutdown rendezvous on leaving workers (#152525) · pytorch/pytorch@8739a8c · GitHub
[go: up one dir, main page]

Skip to content

Commit 8739a8c

Browse files
georgkaleido
authored and pytorchmergebot committed
elastic: do not shutdown rendezvous on leaving workers (#152525)
In #117066, shutdown of the rendezvous was added if a worker shuts down. This is incorrect, because the rendezvous is actually shut down in [this file](https://github.com/pytorch/pytorch/blob/fa6f9eb2be07f6289d2ab4e781077f7fc75dbe55/torch/distributed/launcher/api.py#L290) but should not be shut down if a signal is received. See also [this pull request](#67749). #124819 then tried to remediate the situation by fixing the faulty shutdown for the restart case. But this is only triggered if the agent restarts the training, not if the shutdown of the rendezvous happened before. Removing both of these changes restores the original behavior: the rendezvous should only be shut down when a run completes or fails, not when a single worker leaves. Fixes #150916 Fixes #147064 Pull Request resolved: #152525 Approved by: https://github.com/kiukchung
1 parent 8ac82c3 commit 8739a8c

File tree

3 files changed

+7
-19
lines changed

3 files changed

+7
-19
lines changed

test/distributed/elastic/agent/server/test/api_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,7 @@ def __init__(self, spec):
127127
self.stop_workers_call_count = 0
128128
self.start_workers_call_count = 0
129129

130-
def _stop_workers(
131-
self, worker_group: WorkerGroup, is_restart: bool = False
132-
) -> None:
130+
def _stop_workers(self, worker_group: WorkerGroup) -> None:
133131
# workers are fake, nothing to stop; just clear the rdzv info
134132
worker_group.group_rank = None
135133
worker_group.group_world_size = None

torch/distributed/elastic/agent/server/api.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -457,9 +457,7 @@ def _start_workers(self, worker_group: WorkerGroup) -> dict[int, Any]:
457457
raise NotImplementedError
458458

459459
@abc.abstractmethod
460-
def _stop_workers(
461-
self, worker_group: WorkerGroup, is_restart: bool = False
462-
) -> None:
460+
def _stop_workers(self, worker_group: WorkerGroup) -> None:
463461
r"""Stop all workers in the given worker group.
464462
465463
Implementors must deal with workers in all states defined by
@@ -477,9 +475,7 @@ def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
477475
raise NotImplementedError
478476

479477
@abc.abstractmethod
480-
def _shutdown(
481-
self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False
482-
) -> None:
478+
def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None:
483479
"""Clean up any resources that were allocated during the agent's work.
484480
485481
Args:
@@ -698,7 +694,7 @@ def _restart_workers(self, worker_group: WorkerGroup) -> None:
698694
"""Restart (stops, rendezvous, starts) all local workers in the group."""
699695
role = worker_group.spec.role
700696
logger.info("[%s] Stopping worker group", role)
701-
self._stop_workers(worker_group, is_restart=True)
697+
self._stop_workers(worker_group)
702698
worker_group.state = WorkerState.STOPPED
703699
self._initialize_workers(worker_group)
704700

torch/distributed/elastic/agent/server/local_elastic_agent.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -280,10 +280,8 @@ def _log_watchdog_event(
280280
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
281281
# `torch.distributed.elastic.metrics.prof`.
282282
@prof
283-
def _stop_workers(
284-
self, worker_group: WorkerGroup, is_restart: bool = False
285-
) -> None:
286-
self._shutdown(is_restart=is_restart)
283+
def _stop_workers(self, worker_group: WorkerGroup) -> None:
284+
self._shutdown()
287285

288286
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
289287
# `torch.distributed.elastic.metrics.prof`.
@@ -359,9 +357,7 @@ def _start_workers(self, worker_group: WorkerGroup) -> dict[int, Any]:
359357

360358
return self._pcontext.pids()
361359

362-
def _shutdown(
363-
self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False
364-
) -> None:
360+
def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None:
365361
if self._worker_watchdog is not None:
366362
self._worker_watchdog.stop()
367363
self._worker_watchdog = None
@@ -370,8 +366,6 @@ def _shutdown(
370366
self._health_check_server = None
371367
if self._pcontext:
372368
self._pcontext.close(death_sig)
373-
if not is_restart and self._rdzv_handler:
374-
self._rdzv_handler.shutdown()
375369

376370
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
377371
# `torch.distributed.elastic.metrics.prof`.

0 commit comments

Comments (0)