[torch/elastic] Scale down does not work correctly when agent is killed with SIGINT, SIGTERM #67742
Comments
@gaocegege I've validated that my expository bugfix PR #67749 solves this bug. But there is additional testing to be done to make sure we are actually propagating errors and recording events correctly, since this module is used in production internally and it's been a mammoth effort to get the errors to propagate to the scheduler correctly (for our internal use cases). For the time being you can take this patch and help us validate that it indeed fixes the RendezvousClosedExceptions you are observing.
SGTM! Thanks for your quick fix. I will take this patch and test again.
…tdown() on premature agent failures (pytorch#67749)

Summary: Pull Request resolved: pytorch#67749

Fixes: pytorch#67742

Test Plan: Added unittests. Validated manually:

```
# start agent 0
$ torchrun --rdzv_backend c10d --rdzv_id 123 --rdzv_endpoint localhost:29500 --nnodes 1:2 --nproc_per_node 1 --monitor_interval 1 test.py

# start agent 1
torchrun --rdzv_backend c10d --rdzv_id 123 --rdzv_endpoint localhost:29500 --nnodes 1:2 --nproc_per_node 1 --monitor_interval 1 test.py

# kill agent 0
CTRL+C (SIGINT) or kill -15 (SIGTERM)

# restart it
torchrun --rdzv_backend c10d --rdzv_id 123 --rdzv_endpoint localhost:29500 --nnodes 1:2 --nproc_per_node 1 --monitor_interval 1 test.py
```

Reviewed By: cbalioglu
Differential Revision: D32129005
fbshipit-source-id: 4e695d0b3397951d375ecee321add5faf0cfa3ea
Hi, will this patch be in 1.10.2?
@kiukchung @gaocegege Hi, has this been truly fixed? I have tested the following procedure:

(1) Start two copies of the agent on two nodes, say node.a and node.b:

```
# node.a
$ torchrun --rdzv_backend c10d --rdzv_id 123 --rdzv_endpoint node.a:29500 --nnodes 1:2 --nproc_per_node 1 sleep.py

# node.b
$ torchrun --rdzv_backend c10d --rdzv_id 123 --rdzv_endpoint node.a:29500 --nnodes 1:2 --nproc_per_node 1 sleep.py
```

(2) If I

(3) However, if I kill the agent on node.a (the rendezvous endpoint host), the agent on node.b fails with:

```
WARNING:torch.distributed.elastic.rendezvous.dynamic_rendezvous:The node 'docker.b_3427_0' has failed to send a keep-alive heartbeat to the rendezvous 'test' due to an error of type RendezvousConnectionError.
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 3494 closing signal SIGTERM
WARNING:torch.distributed.elastic.rendezvous.dynamic_rendezvous:The node 'docker.b_3427_0' has failed to shutdown the rendezvous 'test' due to an error of type RendezvousConnectionError.
Traceback (most recent call last):
File "/usr/local/lib/python3.8/dist-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py", line 113, in _call_store
return getattr(self._store, store_op)(*args, **kwargs)
RuntimeError: Broken pipe
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/usr/local/bin/torchrun", line 33, in <module>
sys.exit(load_entry_point('torch==1.14.0a0+44dac51', 'console_scripts', 'torchrun')())
File "/usr/local/lib/python3.8/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
return f(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/distributed/run.py", line 762, in main
run(args)
File "/usr/local/lib/python3.8/dist-packages/torch/distributed/run.py", line 753, in run
elastic_launch(
File "/usr/local/lib/python3.8/dist-packages/torch/distributed/launcher/api.py", line 132, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
File "/usr/local/lib/python3.8/dist-packages/torch/distributed/launcher/api.py", line 237, in launch_agent
result = agent.run()
File "/usr/local/lib/python3.8/dist-packages/torch/distributed/elastic/metrics/api.py", line 129, in wrapper
result = f(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/distributed/elastic/agent/server/api.py", line 709, in run
result = self._invoke_run(role)
File "/usr/local/lib/python3.8/dist-packages/torch/distributed/elastic/agent/server/api.py", line 881, in _invoke_run
num_nodes_waiting = rdzv_handler.num_nodes_waiting()
File "/usr/local/lib/python3.8/dist-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py", line 1080, in num_nodes_waiting
self._state_holder.sync()
File "/usr/local/lib/python3.8/dist-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py", line 409, in sync
get_response = self._backend.get_state()
File "/usr/local/lib/python3.8/dist-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py", line 73, in get_state
base64_state: bytes = self._call_store("get", self._key)
File "/usr/local/lib/python3.8/dist-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py", line 115, in _call_store
raise RendezvousConnectionError(
torch.distributed.elastic.rendezvous.api.RendezvousConnectionError: The connection to the C10d store has failed. See inner exception for details.
```

So, my questions are:

Q1: Is step 3 the expected behavior? I thought it didn't matter which one to kill, but clearly the node hosting the rendezvous endpoint matters.

Q2: I also wonder if step 2 is the expected behavior? When
Yes, the issue has been fixed. What you are observing is a consequence of how the c10d rendezvous backend works. The c10d rdzv works by creating a TCPStore on the node that hosts the rendezvous endpoint (the --rdzv_endpoint host); every other agent is just a client of that store, so the endpoint host is a single point of failure for the rendezvous. In practice you can do either:
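To make that concrete, here is a minimal sketch of the single point of failure being described, using the public `TCPStore` API rather than the internal rendezvous code (the port and key are placeholders, and `localhost` stands in for node.a, the `--rdzv_endpoint` host):

```python
from datetime import timedelta
from torch.distributed import TCPStore

# Roughly speaking, the agent whose address matches --rdzv_endpoint creates the
# store (is_master=True); every other agent merely connects to it as a client.
host_store = TCPStore("localhost", 29500, is_master=True, timeout=timedelta(seconds=30))
client_store = TCPStore("localhost", 29500, is_master=False, timeout=timedelta(seconds=30))

client_store.set("heartbeat", "ok")
print(client_store.get("heartbeat"))  # b'ok'

# If the process hosting `host_store` dies, subsequent set()/get() calls on the
# clients fail with connection errors -- which the c10d rendezvous surfaces as
# the RendezvousConnectionError / "Broken pipe" in the traceback above.
```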
@kiukchung Thanks for your explanation! I also noticed from discussions under other issues that the host node is not fault-tolerant, but this is not clearly pointed out in the official docs. I tried

Finally, is there any hook or entry point where I can inject a custom function to save checkpoints before the workers restart? The recommended way of periodically saving checkpoints can lose a lot of training effort, especially when training large GPT models.
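There is no dedicated pre-restart hook referenced in this thread; one workaround, since the agent delivers SIGTERM to the workers before tearing them down (see the "Sending process ... closing signal SIGTERM" line in the log above), is to trap that signal in the training script and checkpoint there. A sketch, with `model`, the optimizer, and the checkpoint path as hypothetical placeholders:

```python
import signal
import time

_should_stop = False

def _on_sigterm(signum, frame):
    # keep the handler tiny: set a flag and let the training loop do the work
    global _should_stop
    _should_stop = True

signal.signal(signal.SIGTERM, _on_sigterm)

for step in range(1_000_000):
    time.sleep(0.01)  # stand-in for train_one_step(model, optimizer)
    if _should_stop:
        # torch.save({"step": step, "model": model.state_dict()}, "ckpt.pt")  # hypothetical
        print(f"caught SIGTERM at step {step}: checkpoint here, then exit")
        break
```

Keep in mind the agent does not wait indefinitely before force-killing the workers, so whatever runs here has to finish quickly.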
@drcege I encountered the same problem and would like to ask if there is any conclusion about Q2. Thx
Hi @shinytang6, thanks for the question. To summarize your repro:
This is expected in
The way the
If you need to be resilient to failures of all the participating nodes, you'll want to use a rendezvous backend that uses an external rendezvous endpoint. Take a look at
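For illustration, a hedged sketch of the programmatic equivalent with an external rendezvous endpoint, assuming an etcd server reachable at `etcd-host:2379` and the etcd rendezvous backend available (host, port, and run id are placeholders; the CLI analogue would pass `--rdzv_backend etcd --rdzv_endpoint etcd-host:2379` to torchrun):

```python
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

def trainer(msg):
    # per-worker entrypoint; the agent sets RANK, WORLD_SIZE, MASTER_ADDR, etc.
    return f"worker done: {msg}"

config = LaunchConfig(
    min_nodes=1,
    max_nodes=2,
    nproc_per_node=1,
    run_id="123",
    rdzv_backend="etcd",            # rendezvous state lives on the etcd server,
    rdzv_endpoint="etcd-host:2379", # not on any of the training nodes
    max_restarts=3,
    monitor_interval=1,
)

if __name__ == "__main__":
    # returns a dict mapping local rank -> the value returned by `trainer`
    print(elastic_launch(config, trainer)("hello"))
```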
@kiukchung Thank you for your reply! But I think there may be some misunderstandings (I understand that the master node cannot be killed or restarted).
In addition, I also raised an issue (#147064) recording some experiments under PyTorch 2.4.1. I encountered problems similar to this issue (SIGINT/SIGTERM/SIGKILL-related things) and some other questions. If you have any ideas, please take a look.
Hi @shinytang6, apologies, I missed your second question
Yes, you are correct that when
Are you observing that the agent process on
I can think of a few reasons why
Could you provide the logs for
@shinytang6 UPDATE: I was able to repro this issue on a single node as well (running two agents on the same node). Let's continue the conversation in #147064.
Thank you for your detailed and patient explanation; it's really helpful! I tried waiting longer after node.b was killed, and after approximately ten minutes, node.a resumed. It seems related to how I initialized the process group:

```
dist.init_process_group("nccl")
```

In my reproduction process mentioned above, the default timeout was ten minutes (and

```
dist.init_process_group("nccl", timeout=timedelta(seconds=10))
```

But the issues of
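For reference, the shorter-timeout variant mentioned above as a minimal sketch (the 60-second value is an arbitrary placeholder and must exceed your longest collective):

```python
from datetime import timedelta
import torch.distributed as dist

# Once a peer disappears, pending collectives fail after this timeout, so a
# smaller value lets the surviving agent notice and restart workers sooner
# than the default.
dist.init_process_group("nccl", timeout=timedelta(seconds=60))
```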
🐛 Bug
In multi-node training, when one or more elastic agents get killed with SIGINT or SIGTERM (and additionally SIGHUP and SIGQUIT on Windows), scale down does not work even when the surviving agents still satisfy the min/max node sizes.
This is because of two things:
1. When creating the trainer child processes, the agent self-registers a termination handler for SIGINT and SIGTERM (and on Windows adds SIGHUP and SIGQUIT as well) that raises a `SignalException` (pytorch/torch/distributed/elastic/multiprocessing/api.py, line 233 in cd51d2a).

2. In `launcher/api.py#launch_agent()` (pytorch/torch/distributed/launcher/api.py, line 274 in cd51d2a), `agent.run()` is wrapped in a try-catch-finally block where the finally block calls `rdzv_handler.shutdown()` on an exception.

The issue with 1) and 2) combined is that when the agent dies with a signal, it produces a Python exception (which is not the norm, other than for SIGINT), and hence `rdzv_handler.shutdown()` is invoked in the finally block. However, `rdzv_handler.shutdown()` should ONLY be called during an orderly shutdown (e.g. when the job has finished successfully or has exhausted max_restarts), and SHOULD NOT be called on premature agent failures. When the rendezvous's shutdown() method is called, it (by design) closes the current rendezvous permanently, disallowing any further scaling events: any agent attempting to join the training job will then fail with a `RendezvousClosedException`.
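A self-contained sketch of this failure mode (the `SignalException` and `rdzv_shutdown` names below are illustrative stand-ins, not the actual torch.distributed.elastic code): run it, send it SIGINT or SIGTERM, and the "rendezvous shutdown" executes even though the process died prematurely.

```python
import signal
import time

class SignalException(Exception):
    """Stand-in for the exception the agent's termination handler raises."""

def _terminate(signum, frame):
    raise SignalException(f"received signal {signum}")

def rdzv_shutdown():
    print("rendezvous closed permanently -- late joiners now get RendezvousClosedException")

signal.signal(signal.SIGINT, _terminate)
signal.signal(signal.SIGTERM, _terminate)

try:
    while True:          # stand-in for agent.run()
        time.sleep(1)
finally:
    rdzv_shutdown()      # runs on the SignalException too, which is the bug
```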
.To Fix
NOTE: Removing the termination handlers for SIGTERM and SIGINT IS NOT A VIABLE FIX, since this logic was specifically added to ensure (in an OS-independent way) that agents do not leave orphaned child workers when the agent dies prematurely.
What we probably need to do is catch the `SignalException` explicitly in pytorch/torch/distributed/launcher/api.py (line 265 in cd51d2a) and not call `rdzv_handler.shutdown()` in this case.

NOTE: We have a scale down unittest (pytorch/test/distributed/elastic/agent/server/test/local_elastic_agent_test.py, line 810 in cd51d2a)
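For contrast, the shape the fix described above could take, sketched with assumed names rather than the actual change in #67749: remember whether the failure was signal-driven and skip the rendezvous shutdown in that case.

```python
class SignalException(Exception):
    """Stand-in for torch.distributed.elastic.multiprocessing.api.SignalException."""

def fake_agent_run():
    raise SignalException("agent killed with SIGTERM")  # simulate a premature death

def rdzv_shutdown():
    print("closing rendezvous (orderly shutdown only)")

shutdown_rdzv = True
try:
    fake_agent_run()
except SignalException:
    shutdown_rdzv = False  # premature agent failure: leave the rendezvous open
    raise
finally:
    if shutdown_rdzv:
        rdzv_shutdown()
```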
To Reproduce
Steps to reproduce the behavior:
Expected behavior
When the agent is killed and restarted, it should be able to join the job.
Environment
DOES NOT MATTER
Additional context