[torch/elastic] Scale down does not work correctly when agent is killed with SIGINT, SIGTERM #67742
Closed

Description

@kiukchung

🐛 Bug

In multi-node training, when one or more elastic agents is killed with SIGINT or SIGTERM (and additionally SIGHUP or SIGQUIT on non-Windows platforms), scale down does not work even when the surviving agents still satisfy the min/max node sizes.

This is because of two things:

  1. When creating the trainer child processes, the agent self-registers a termination handler for SIGINT and SIGTERM (and for SIGHUP and SIGQUIT as well on non-Windows platforms) that raises a SignalException:

    signal.signal(signal.SIGTERM, _terminate_process_handler)

  2. In launcher/api.py#launch_agent(), agent.run() is wrapped in a try-except-finally block, and the finally block calls rdzv_handler.shutdown() even when agent.run() exits with an exception:

    rdzv_handler.shutdown()

The issue with 1) and 2) combined is that when the agent is killed with one of these signals, the handler raises a Python exception (which, aside from SIGINT's default KeyboardInterrupt, is not the normal behavior for these signals), and hence rdzv_handler.shutdown() is invoked in the finally block. However, rdzv_handler.shutdown() should ONLY be called during an orderly shutdown (e.g. when the job has finished successfully or has exhausted max_restarts), and SHOULD NOT be called on premature agent failures.

When the shutdown() method on the rendezvous handler is called, it (by design) closes the current rendezvous permanently, disallowing any further scaling events: any agent that subsequently attempts to join the training job fails with a RendezvousClosedError.
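To make the interaction concrete, here is a minimal, self-contained sketch of the two pieces described above. SignalException and _terminate_process_handler mirror the agent-side names mentioned in 1); fake_launch_agent is a made-up stand-in for the try/finally structure in launcher/api.py#launch_agent(), not the actual upstream code.

    import os
    import signal

    class SignalException(Exception):
        # Stand-in for the exception raised by the agent's termination handler.
        def __init__(self, msg, sigval):
            super().__init__(msg)
            self.sigval = sigval

    def _terminate_process_handler(signum, frame):
        # Turn SIGTERM/SIGINT into a Python exception so the agent can tear
        # down its child workers instead of leaving them orphaned.
        raise SignalException(f"Process {os.getpid()} got signal: {signum}", sigval=signum)

    # 1) the agent registers the handler when it starts the child processes
    signal.signal(signal.SIGTERM, _terminate_process_handler)
    signal.signal(signal.SIGINT, _terminate_process_handler)

    # 2) simplified launcher flow: shutdown() runs even when a SignalException
    # propagates, permanently closing the rendezvous for everyone else.
    def fake_launch_agent(agent_run, rdzv_handler):
        try:
            return agent_run()
        finally:
            rdzv_handler.shutdown()

Because rdzv_handler.shutdown() sits in the finally block, a premature kill is treated the same as an orderly exit.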

To Fix

NOTE: Removing the termination handlers for SIGTERM and SIGINT IS NOT A VIABLE FIX, since this logic was specifically added to ensure (in an OS-independent way) that agents do not leave orphaned child workers when the agent dies prematurely.

What we probably need to do is catch the SignalException explicitly in launcher/api.py#launch_agent(), next to the existing except ChildFailedError: handling, and NOT call rdzv_handler.shutdown() in this case, as sketched below.
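A minimal sketch of that change, assuming the simplified structure from the sketch above (the real launch_agent() also handles results, error files, and events; this is not the verbatim upstream code):

    shutdown_rdzv = True
    try:
        result = agent.run()
        # ... existing handling of the result / ChildFailedError stays as-is ...
    except SignalException:
        # The agent was killed prematurely (SIGINT, SIGTERM, ...): skip the
        # orderly rendezvous shutdown so surviving or restarted agents can
        # still join the job.
        shutdown_rdzv = False
        raise
    finally:
        if shutdown_rdzv:
            rdzv_handler.shutdown()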

NOTE: We have a scale-down unittest which didn't catch this bug, since the bug comes from a cross-module interaction between the newly introduced SignalException in the agent and the existing try-except-finally block in launcher/api.py. When we fix this, we should also introduce a unittest that specifically tests this cross-module dependency (see the sketch below).
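A self-contained sketch of such a test. SignalException and launch_agent_flow below are simplified stand-ins for the real agent exception and launcher flow; an actual test would drive launch_agent() with a mocked agent and rendezvous handler.

    import unittest
    from unittest.mock import MagicMock

    class SignalException(Exception):
        # Stand-in for the agent's SignalException.
        pass

    def launch_agent_flow(agent, rdzv_handler):
        # Same simplified control flow as the proposed fix above.
        shutdown_rdzv = True
        try:
            return agent.run()
        except SignalException:
            shutdown_rdzv = False
            raise
        finally:
            if shutdown_rdzv:
                rdzv_handler.shutdown()

    class ScaleDownOnSignalTest(unittest.TestCase):
        def test_rdzv_not_closed_when_agent_killed_by_signal(self):
            agent = MagicMock()
            agent.run.side_effect = SignalException("got SIGTERM")
            rdzv_handler = MagicMock()

            with self.assertRaises(SignalException):
                launch_agent_flow(agent, rdzv_handler)

            # The rendezvous must stay open so other agents can still join.
            rdzv_handler.shutdown.assert_not_called()

    if __name__ == "__main__":
        unittest.main()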

To Reproduce

Steps to reproduce the behavior:

  1. Run two copies of the agent, each running 1 worker that just sleeps indefinitely (see the sleep.py sketch after the traceback), and make sure you run them from a terminal multiplexer:

$ torchrun --rdzv_backend c10d --rdzv_id 123 --rdzv_endpoint localhost:29500 --nnodes 1:2 --nproc_per_node 1 sleep.py

$ torchrun --rdzv_backend c10d --rdzv_id 123 --rdzv_endpoint localhost:29500 --nnodes 1:2 --nproc_per_node 1 sleep.py

  2. Ctrl+C one of the agents (it doesn't matter which).
  3. Try to run the same command again (i.e. restart the agent). This will throw:
Traceback (most recent call last):
  File "/home/ubuntu/.pyenv/versions/venv385/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper
    return f(*args, **kwargs)
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/run.py", line 719, in main
    run(args)
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/run.py", line 710, in run
    elastic_launch(
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 252, in launch_agent
    result = agent.run()
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/elastic/metrics/api.py", line 125, in wrapper
    result = f(*args, **kwargs)
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/elastic/agent/server/api.py", line 709, in run
    result = self._invoke_run(role)
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/elastic/agent/server/api.py", line 837, in _invoke_run
    self._initialize_workers(self._worker_group)
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/elastic/metrics/api.py", line 125, in wrapper
    result = f(*args, **kwargs)
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/elastic/agent/server/api.py", line 678, in _initialize_workers
    self._rendezvous(worker_group)
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/elastic/metrics/api.py", line 125, in wrapper
    result = f(*args, **kwargs)
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/elastic/agent/server/api.py", line 538, in _rendezvous
    store, group_rank, group_world_size = spec.rdzv_handler.next_rendezvous()
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py", line 1024, in next_rendezvous
    self._op_executor.run(join_op, deadline)
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py", line 634, in run
    raise RendezvousClosedError()
torch.distributed.elastic.rendezvous.api.RendezvousClosedError
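For reference, the issue does not include sleep.py; a minimal stand-in worker that just sleeps indefinitely could look like this:

    # sleep.py - trivial worker that sleeps forever so the agent's child
    # process stays alive until it is torn down.
    import time

    if __name__ == "__main__":
        while True:
            time.sleep(60)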

Expected behavior

When the agent is killed and restarted, it should be able to join the job.

Environment

DOES NOT MATTER

Additional context

  1. [distributed elastic] rendezvous brain split with etcd backend #67616 (comment)
  2. https://discuss.pytorch.org/t/training-process-is-terminated-when-node-fails-for-torch-elastic/135580/5
  3. [distributed elastic] rendezvous brain split with etcd backend #67616 (comment)

Labels

module: elastic, oncall: r2p, triaged
