[torch/elastic] unexpected behavior of torch elastic · Issue #147064 · pytorch/pytorch

Closed · shinytang6 opened this issue Feb 13, 2025 · 17 comments
Labels: module: elastic (Related to torch.distributed.elastic), oncall: distributed (Add this issue/PR to distributed oncall triage queue), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@shinytang6 commented Feb 13, 2025

🐛 Describe the bug

Hi all, I conducted some simple tests with torch elastic to understand its behavior under node failures, and I encountered several outcomes that do not match the official documentation.

Fault Tolerance & Elasticity test

Master node A command:

$ torchrun  --nnodes=1:2 --nproc-per-node=1 --rdzv-id=0 --rdzv-backend=c10d --rdzv-endpoint=MASTER_ADDR:MASTER_PORT --max-restarts=10 elastic-demo.py

Worker node B command:

$ torchrun  --nnodes=1:2 --nproc-per-node=1 --rdzv-id=0 --rdzv-backend=c10d --rdzv-endpoint=MASTER_ADDR:MASTER_PORT --max-restarts=10 elastic-demo.py
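
The contents of elastic-demo.py are not shown here; for context, a minimal hypothetical sketch of what such a script could look like (using the gloo backend so it runs without GPUs; my actual script uses NCCL, as mentioned further down) is:

# Hypothetical stand-in for elastic-demo.py: all-reduce a tensor in a loop so
# that a vanished peer makes the collective hang, which is what the cases
# below exercise. torchrun supplies the env:// rendezvous variables.
import torch
import torch.distributed as dist

def main():
    dist.init_process_group("gloo")  # the real script uses "nccl"
    rank = dist.get_rank()
    for step in range(100_000):
        t = torch.ones(1)
        dist.all_reduce(t)  # blocks or fails here when a peer node disappears
        if rank == 0 and step % 1000 == 0:
            print(f"step={step} world_size={dist.get_world_size()} sum={t.item()}")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()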

Case 1

  • Both nodes start the task simultaneously, and the training begins normally.
  • After terminating the worker node B task (using Ctrl+C or kill -15), master node A hangs and training stalls.
  • Restarting the worker node B task sometimes results in an error (torch.distributed.elastic.rendezvous.api.RendezvousClosedError), but it occasionally restarts successfully. This behavior is irregular, and the --max-restarts parameter does not seem to take effect: the error occurs whether I increase or decrease its value, and it appears to depend on the timing of the rejoin (I am not sure about that).

Case 2

  • Both nodes start the task simultaneously, and the training begins normally.
  • After terminating the worker node B task (using kill -9), master node A hangs and the training stalls.
  • Restarting the worker node B task allows the training to restart, but the --max-restarts parameter does not seem to take effect either.

Case 3

  • Both nodes start the task simultaneously, and the training begins normally.
  • After terminating master node A’s task (using ctrl+c, kill -15, or kill -9), the entire training crashes immediately.

The detailed error message:

Traceback (most recent call last):
  File "/opt/conda/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 348, in wrapper
    return f(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/run.py", line 901, in main
    run(args)
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/run.py", line 892, in run
    elastic_launch(
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 133, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 255, in launch_agent
    result = agent.run()
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/metrics/api.py", line 124, in wrapper
    result = f(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/agent/server/api.py", line 680, in run
    result = self._invoke_run(role)
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/agent/server/api.py", line 829, in _invoke_run
    self._initialize_workers(self._worker_group)
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/metrics/api.py", line 124, in wrapper
    result = f(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/agent/server/api.py", line 652, in _initialize_workers
    self._rendezvous(worker_group)
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/metrics/api.py", line 124, in wrapper
    result = f(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/agent/server/api.py", line 489, in _rendezvous
    rdzv_info = spec.rdzv_handler.next_rendezvous()
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py", line 1125, in next_rendezvous
    self._op_executor.run(
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py", line 667, in run
    raise RendezvousClosedError
torch.distributed.elastic.rendezvous.api.RendezvousClosedError

So my questions are:

  1. Is the differing behavior across signals (SIGINT, SIGTERM, SIGKILL) expected?
  2. Why does the --max-restarts parameter not seem to affect the restart behavior? Is there something I'm missing in the configuration or use of this parameter?

Versions

torch version:

$ pip show torch
Name: torch
Version: 2.4.1
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3
Location: /opt/conda/lib/python3.8/site-packages
Requires: filelock, fsspec, jinja2, networkx, nvidia-cublas-cu12, nvidia-cuda-cupti-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-runtime-cu12, nvidia-cudnn-cu12, nvidia-cufft-cu12, nvidia-curand-cu12, nvidia-cusolver-cu12, nvidia-cusparse-cu12, nvidia-nccl-cu12, nvidia-nvtx-cu12, sympy, triton, typing-extensions
Required-by: accelerate, bitsandbytes, deepspeed, flash_attn, flash_attn_1, peft, torchaudio, torchpippy, torchvision, transformer_engine, trl

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @dzhulgakov

@shinytang6 (Author) commented

By the way, when I switch to torch==2.2.0, the RendezvousClosedError errors no longer occur, but --max-restarts still does not seem to take effect.

@janeyx99 added the oncall: distributed and module: elastic labels Feb 13, 2025
@kiukchung (Collaborator) commented Feb 13, 2025

Thanks for the details. Could you clarify in your examples above:

  1. Which process was terminated? Was it the torchrun (aka agent) process or the PyTorch worker process (the process running elastic-demo.py)?
  2. Which process hangs?

The topology of the actual UNIX processes when invoking the torchrun command as shown above on both nodes looks like below:

[Image: diagram of the UNIX process topology created by the torchrun command on each node]

@kiukchung (Collaborator) commented Feb 13, 2025

I was able to repro Case 1 (RendezvousClosedError) by running two agents on a single node.

What I ran (on my desktop):

$ torchrun --rdzv_backend=c10d --rdzv_endpoint=localhost:29500 --nnodes=1:2 --nproc_per_node=1 --max_restarts=3 repro.py

# on a different terminal (same host) after `repro.py` starts running 
# to ensure that the first torchrun hosts the c10d store
$ torchrun --rdzv_backend=c10d --rdzv_endpoint=localhost:29500 --nnodes=1:2 --nproc_per_node=1 --max_restarts=3 repro.py

Ctrl+C-ing the second torchrun and then restarting it causes the first to error out with:

E0213 12:07:20.086000 870572 torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 0 (pid: 876053) of binary: /usr/local/google/home/kiuk/.pyenv/versions/3.11.4/envs/venv311/bin/python3.11
I0213 12:07:20.086000 870572 torch/distributed/elastic/agent/server/api.py:889] [default] Worker group FAILED. 3/3 attempts left; will restart worker group
I0213 12:07:20.086000 870572 torch/distributed/elastic/agent/server/api.py:699] [default] Stopping worker group
I0213 12:07:20.087000 870572 torch/distributed/elastic/agent/server/api.py:677] [default] Rendezvous'ing worker group
Traceback (most recent call last):
  File "/usr/local/google/home/kiuk/.pyenv/versions/venv311/bin/torchrun", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/usr/local/google/home/kiuk/.pyenv/versions/3.11.4/envs/venv311/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/kiuk/.pyenv/versions/3.11.4/envs/venv311/lib/python3.11/site-packages/torch/distributed/run.py", line 918, in main
    run(args)
  File "/usr/local/google/home/kiuk/.pyenv/versions/3.11.4/envs/venv311/lib/python3.11/site-packages/torch/distributed/run.py", line 909, in run
    elastic_launch(
  File "/usr/local/google/home/kiuk/.pyenv/versions/3.11.4/envs/venv311/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/kiuk/.pyenv/versions/3.11.4/envs/venv311/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 260, in launch_agent
    result = agent.run()
             ^^^^^^^^^^^
  File "/usr/local/google/home/kiuk/.pyenv/versions/3.11.4/envs/venv311/lib/python3.11/site-packages/torch/distributed/elastic/metrics/api.py", line 137, in wrapper
    result = f(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/kiuk/.pyenv/versions/3.11.4/envs/venv311/lib/python3.11/site-packages/torch/distributed/elastic/agent/server/api.py", line 711, in run
    result = self._invoke_run(role)
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/kiuk/.pyenv/versions/3.11.4/envs/venv311/lib/python3.11/site-packages/torch/distributed/elastic/agent/server/api.py", line 899, in _invoke_run
    self._restart_workers(self._worker_group)
  File "/usr/local/google/home/kiuk/.pyenv/versions/3.11.4/envs/venv311/lib/python3.11/site-packages/torch/distributed/elastic/metrics/api.py", line 137, in wrapper
    result = f(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/kiuk/.pyenv/versions/3.11.4/envs/venv311/lib/python3.11/site-packages/torch/distributed/elastic/agent/server/api.py", line 702, in _restart_workers
    self._initialize_workers(worker_group)
  File "/usr/local/google/home/kiuk/.pyenv/versions/3.11.4/envs/venv311/lib/python3.11/site-packages/torch/distributed/elastic/metrics/api.py", line 137, in wrapper
    result = f(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/kiuk/.pyenv/versions/3.11.4/envs/venv311/lib/python3.11/site-packages/torch/distributed/elastic/agent/server/api.py", line 683, in _initialize_workers
    self._rendezvous(worker_group)
  File "/usr/local/google/home/kiuk/.pyenv/versions/3.11.4/envs/venv311/lib/python3.11/site-packages/torch/distributed/elastic/metrics/api.py", line 137, in wrapper
    result = f(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/kiuk/.pyenv/versions/3.11.4/envs/venv311/lib/python3.11/site-packages/torch/distributed/elastic/agent/server/api.py", line 500, in _rendezvous
    rdzv_info = spec.rdzv_handler.next_rendezvous()
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/kiuk/.pyenv/versions/3.11.4/envs/venv311/lib/python3.11/site-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py", line 1162, in next_rendezvous
    self._op_executor.run(join_op, deadline, self._get_deadline)
  File "/usr/local/google/home/kiuk/.pyenv/versions/3.11.4/envs/venv311/lib/python3.11/site-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py", line 676, in run
    raise RendezvousClosedError

Will take a closer look and report back here.

@kiukchung self-assigned this Feb 13, 2025
@shinytang6 (Author) commented Feb 14, 2025

To provide further context, I initialize the process group with dist.init_process_group("nccl") (with the default timeout of 10 minutes). If I decrease the timeout to 10 seconds, the RendezvousClosedError issue seems to be alleviated; however, it still occurs after worker node B exits and rejoins several times (--max-restarts is set to a very large number).
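
For reference, a minimal sketch of the timeout change described above (assuming the NCCL backend, as stated):

from datetime import timedelta
import torch.distributed as dist

# The default process-group timeout is 10 minutes; lowering it makes hanging
# collectives fail sooner when a peer goes away, so the agent can react earlier.
dist.init_process_group("nccl", timeout=timedelta(seconds=10))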

@shinytang6 (Author) commented

Thanks for the details. Could you clarify in your examples above:

  1. Which process was terminated? Was it the torchrun (aka agent) process or the PyTorch worker process (the process running elastic-demo.py)?
  2. Which process hangs?

The topology of the actual UNIX processes when invoking the torchrun command as shown above on both nodes looks like below:

[Image: diagram of the UNIX process topology created by the torchrun command on each node]

  1. I killed the torchrun process on worker node B.
  2. The hang issue on master node A is resolved; see my comment here ([torch/elastic] Scale down does not work correctly when agent is killed with SIGINT, SIGTERM #67742 (comment)).

@shinytang6 (Author) commented

Another discovery regarding the --max-restarts parameter: if I exit and rejoin worker node B very quickly, it seems possible to completely bypass the --max-restarts limit, allowing infinite restarts even with --max-restarts=0. Only when I exit node B and wait for a while does node A register it as a restart.
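
In case it helps reproduce this, a hypothetical sketch of the quick exit/rejoin loop on node B (the command and the MASTER_ADDR:MASTER_PORT placeholder are taken from the commands at the top of this issue):

import signal
import subprocess
import time

CMD = [
    "torchrun", "--nnodes=1:2", "--nproc-per-node=1", "--rdzv-id=0",
    "--rdzv-backend=c10d", "--rdzv-endpoint=MASTER_ADDR:MASTER_PORT",
    "--max-restarts=0", "elastic-demo.py",
]

for attempt in range(5):
    proc = subprocess.Popen(CMD)        # start node B's agent
    time.sleep(30)                      # let training run for a bit
    proc.send_signal(signal.SIGTERM)    # terminate the agent gracefully
    proc.wait()
    # Rejoin immediately: node A does not always count this as a restart,
    # so --max-restarts=0 does not stop the job.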

@kiukchung (Collaborator) commented

Another discovery regarding the --max-restarts parameter: if I exit and rejoin worker node B very quickly, it seems possible to completely bypass the --max-restarts limit, allowing infinite restarts even with --max-restarts=0. Only when I exit node B and wait for a while does node A register it as a restart.

Thanks for the note; I will try a repro and cut a separate bug report for this once I confirm it.

@d4l3k (Member) commented Feb 21, 2025

I think --max-restarts is enforced on a per-agent basis. The idea is that torchelastic will restart your worker process N times.

When you exit node B, are you restarting torchelastic or only the training script?

IIRC, if you're restarting the whole torchelastic agent/process, you instead need to use the scheduler to manage max retries.
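
To illustrate the per-agent point, a minimal sketch (field names as I recall them from torch.distributed.launcher.api; double-check against your torch version): max_restarts is part of the agent-side LaunchConfig, so each torchrun/agent process keeps its own restart budget, and relaunching the agent starts that budget over.

from torch.distributed.launcher.api import LaunchConfig, elastic_launch

def train():
    # placeholder for the per-worker entrypoint (what elastic-demo.py does)
    ...

if __name__ == "__main__":
    config = LaunchConfig(
        min_nodes=1,
        max_nodes=2,
        nproc_per_node=1,
        run_id="0",
        rdzv_backend="c10d",
        rdzv_endpoint="MASTER_ADDR:MASTER_PORT",
        max_restarts=10,  # budget tracked by this agent only, not across the job
    )
    elastic_launch(config, train)()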

@fduwjj added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Apr 23, 2025
@NikitaShalagin commented

@shinytang6 Have you found any temporary workarounds? Like custom signal handling, or perhaps monkeypatching torch?

@shinytang6 (Author) commented

@shinytang6 Have you found any temporary workarounds? Like custom signal handling, or perhaps monkeypatching torch?

@NikitaShalagin No, I had to downgrade to PyTorch 2.2.0, where torch elastic works.

@shinytang6 (Author) commented

By the way, is there any progress on this issue? cc @kiukchung

@NikitaShalagin commented

@shinytang6 Have you tried the new torch 2.7 to see whether it fixes this issue?

@shinytang6 (Author) commented Apr 28, 2025

@shinytang6 Have you tried the new torch 2.7 to see whether it fixes this issue?

@NikitaShalagin Not yet; I have only tried 2.2.0 (works) and 2.3.x & 2.4.x (do not work).

@NikitaShalagin commented

@shinytang6 By the way, I've tested on 2.2.0 and 2.2.2. 2.2.2 does not work, so the root cause is probably somewhere between those versions.

@shinytang6 (Author) commented

@shinytang6 By the way, I've tested on 2.2.0 and 2.2.2. 2.2.2 does not work, so the root cause is probably somewhere between those versions.

@NikitaShalagin Good to know. I checked the version I tested before and confirmed that the one that worked was torch 2.2.0.

@kiukchung (Collaborator) commented

By the way, is there any progress on this issue? cc @kiukchung

Hi @shinytang6, I haven't had a chance to dig into it. I will try looking into it this week unless @d4l3k has already looked into it.

@georgkaleido (Contributor) commented

From my investigation, two things are at play here:

Cases 1 & 3:
Leaving workers may (depending on whether they succeed with the state update) close the rendezvous.
This is fixed in #152525

Case 2:
Here, the killed worker cannot accidentally close the rendezvous (as it is killed), but the agent does not correctly start a new rendezvous. This is also described in #111646 and fixed in #151220.
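
A highly simplified, hypothetical sketch (not the actual torch.distributed code paths) of why the signals behave differently under this explanation: a graceful agent exit runs cleanup that can mark the rendezvous closed, while kill -9 skips cleanup entirely, so the rendezvous stays open but no new one is started correctly.

class ToyRendezvous:
    """Stand-in for the shared c10d rendezvous state."""
    def __init__(self):
        self.closed = False

    def shutdown(self):
        # Once closed, any node that later tries to (re)join sees
        # RendezvousClosedError, as in Cases 1 and 3 above.
        self.closed = True

def agent(rdzv):
    try:
        raise KeyboardInterrupt  # Ctrl+C / kill -15 interrupts the agent ...
    finally:
        rdzv.shutdown()          # ... but cleanup still runs and closes the
                                 # rendezvous. With kill -9 the process is gone
                                 # before any cleanup can run, so the rendezvous
                                 # stays open (Case 2).

rdzv = ToyRendezvous()
try:
    agent(rdzv)
except KeyboardInterrupt:
    pass
print("rendezvous closed after graceful exit:", rdzv.closed)  # -> True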
