[torch/elastic] Scale down does not work correctly when agent is killed with SIGINT, SIGTERM #67742


Closed
kiukchung opened this issue Nov 3, 2021 · 12 comments
Labels: module: elastic (Related to torch.distributed.elastic) · oncall: r2p (Add this issue/PR to R2P (elastic) oncall triage queue) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

kiukchung (Collaborator) commented Nov 3, 2021

🐛 Bug

In multi-node training, when one or more elastic agents get killed with SIGINT or SIGTERM (and additionally SIGHUP and SIGQUIT on Windows), scale down does not work even when the surviving agents are compliant with the min-max sizes.

This is because of two things:

  1. When creating the trainer child processes, the agent self-registers a termination handler for SIGINT and SIGTERM (and on Windows adds in SIGHUP and SIGQUIT as well) that raises a SignalException:

     signal.signal(signal.SIGTERM, _terminate_process_handler)

  2. In launcher/api.py#launch_agent(), agent.run() is wrapped in a try/except/finally block where the finally block calls rdzv_handler.shutdown() on an exception.

The issue with 1) and 2) combined is that when the agent dies with a signal, it produces a Python exception (which is not the norm for signals other than SIGINT), and hence rdzv_handler.shutdown() is invoked in the finally block. However, rdzv_handler.shutdown() should ONLY be called during an orderly shutdown (e.g. when the job has finished successfully or has exhausted max_restarts), and SHOULD NOT be called on premature agent failures.

When the shutdown() method on the rendezvous is called, this (by design) closes the current rendezvous permanently, disallowing any further scaling events: any agent attempting to join the training job will fail with a RendezvousClosedException.
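To see why the finally block fires at all in this case, here is a minimal, self-contained sketch (stand-in names, not PyTorch code): a termination handler that raises turns an external SIGTERM into an ordinary Python exception, which then flows through a launcher-style try/finally and closes the rendezvous.

```python
import os
import signal


class SignalException(Exception):
    """Stand-in for the exception raised by the agent's termination handler."""


def _terminate_process_handler(signum, frame):
    # Turn the asynchronous signal into a synchronous Python exception.
    raise SignalException(f"process {os.getpid()} got signal: {signum}")


def fake_rdzv_shutdown():
    # Stand-in for rdzv_handler.shutdown(): closes the rendezvous permanently,
    # so no agent can ever join this rendezvous again.
    print("rendezvous closed permanently")


signal.signal(signal.SIGTERM, _terminate_process_handler)

try:
    # Stand-in for agent.run(): simulate the operator killing the agent
    # (Unix only; signal semantics differ on Windows).
    os.kill(os.getpid(), signal.SIGTERM)
except SignalException as e:
    print(f"agent died prematurely: {e}")
finally:
    # This is the bug: the finally block also runs for the premature,
    # signal-induced exit, so the rendezvous gets closed anyway.
    fake_rdzv_shutdown()
```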

To Fix

NOTE: Removing the termination handlers for SIGTERM and SIGINT IS NOT A VIABLE FIX, since this logic was specifically added to ensure (in an OS-independent way) that agents do not leave orphaned child workers when the agent dies prematurely.

What we probably need to do is catch the SignalException explicitly in launch_agent() (alongside the existing except ChildFailedError: handler) and NOT call rdzv_handler.shutdown() in this case.
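A rough sketch of what that could look like in launch_agent() (illustrative only; the real code, imports, and helper names differ, _build_agent below is a hypothetical placeholder, and the actual change landed in #67749):

```python
# Sketch only -- assumes SignalException is the exception raised by the
# agent's termination handler as described above; import paths are assumptions.
from torch.distributed.elastic.multiprocessing.api import SignalException
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError


def launch_agent(config, entrypoint, args):
    agent, spec = _build_agent(config, entrypoint, args)  # hypothetical helper
    shutdown_rdzv = True
    try:
        result = agent.run()
        return result.return_values
    except ChildFailedError:
        # Workers failed and max_restarts was exhausted: this is an orderly
        # shutdown, so closing the rendezvous in `finally` is appropriate.
        raise
    except SignalException:
        # The agent itself was killed (SIGINT/SIGTERM): this is a premature
        # failure, so do NOT close the rendezvous -- surviving or restarted
        # agents must still be able to re-join and scale down.
        shutdown_rdzv = False
        raise
    finally:
        if shutdown_rdzv:
            spec.rdzv_handler.shutdown()
```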

NOTE: We have a scale-down unittest which didn't catch this bug, since the bug is due to a cross-module interaction between the newly introduced SignalException in the agent and the existing try/except/finally block in launcher/api.py. When we fix this, we should also introduce a unittest specifically testing this cross-module dependency.

To Reproduce

Steps to reproduce the behavior:

  1. Run two copies of the agent (each agent runs 1 worker - the worker just sleeps indefinitely; a minimal sleep.py sketch appears after the traceback below) - just make sure you run it from a terminal multiplexer
$ torchrun --rdzv_backend c10d --rdzv_id 123 --rdzv_endpoint localhost:29500 --nnodes 1:2 --nproc_per_node 1 sleep.py

$ torchrun --rdzv_backend c10d --rdzv_id 123 --rdzv_endpoint localhost:29500 --nnodes 1:2 --nproc_per_node 1 sleep.py
  2. ctrl+c one of the agents (doesn't matter which)
  3. Try to run the same command again (aka restart the agent). This will throw:
Traceback (most recent call last):
  File "/home/ubuntu/.pyenv/versions/venv385/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper
    return f(*args, **kwargs)
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/run.py", line 719, in main
    run(args)
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/run.py", line 710, in run
    elastic_launch(
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 252, in launch_agent
    result = agent.run()
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/elastic/metrics/api.py", line 125, in wrapper
    result = f(*args, **kwargs)
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/elastic/agent/server/api.py", line 709, in run
    result = self._invoke_run(role)
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/elastic/agent/server/api.py", line 837, in _invoke_run
    self._initialize_workers(self._worker_group)
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/elastic/metrics/api.py", line 125, in wrapper
    result = f(*args, **kwargs)
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/elastic/agent/server/api.py", line 678, in _initialize_workers
    self._rendezvous(worker_group)
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/elastic/metrics/api.py", line 125, in wrapper
    result = f(*args, **kwargs)
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/elastic/agent/server/api.py", line 538, in _rendezvous
    store, group_rank, group_world_size = spec.rdzv_handler.next_rendezvous()
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py", line 1024, in next_rendezvous
    self._op_executor.run(join_op, deadline)
  File "/home/ubuntu/.pyenv/versions/3.8.5/envs/venv385/lib/python3.8/site-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py", line 634, in run
    raise RendezvousClosedError()
torch.distributed.elastic.rendezvous.api.RendezvousClosedError
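
For completeness, a minimal sleep.py for the repro above might look like this (hypothetical; any script that blocks indefinitely will do):

```python
# sleep.py - trivial worker that sleeps forever so the agents stay up.
import time

if __name__ == "__main__":
    while True:
        time.sleep(60)
```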

Expected behavior

When the agent is killed and restarted, it should be able to join the job.

Environment

DOES NOT MATTER

Additional context

  1. [distributed elastic] rendezvous brain split with etcd backend #67616 (comment)
  2. https://discuss.pytorch.org/t/training-process-is-terminated-when-node-fails-for-torch-elastic/135580/5
  3. [distributed elastic] rendezvous brain split with etcd backend #67616 (comment)
kiukchung (Collaborator, Author) commented

@gaocegege I've validated that my expository bugfix PR #67749 solves this bug. But there is additional testing to be done to make sure we are actually propagating errors and recording events correctly, since this module is used in production internally and it's been a mammoth effort to get the errors to propagate to the scheduler correctly (for our internal use-cases).

For the time being you can take this patch and help us validate that it indeed fixes the RendezvousClosedExceptions you are observing.

gaocegege (Contributor) commented

SGTM! Thanks for your quick fix. I will take this patch and test again.

bdhirsh added the triaged label Nov 4, 2021
kiukchung pushed a commit to kiukchung/pytorch that referenced this issue Nov 5, 2021
…tdown() on premature agent failures (pytorch#67749)

Summary:
Pull Request resolved: pytorch#67749

Fixes: pytorch#67742

Test Plan:
Added unittests.

Validated manually:

```
# start agent 0
$ torchrun --rdzv_backend c10d --rdzv_id 123 --rdzv_endpoint localhost:29500 --nnodes 1:2 --nproc_per_node 1 --monitor_interval 1 test.py

# start agent 1
torchrun --rdzv_backend c10d --rdzv_id 123 --rdzv_endpoint localhost:29500 --nnodes 1:2 --nproc_per_node 1 --monitor_interval 1 test.py

# kill agent 0
CTRL+C (SIGINT) or kill -15 (SIGTERM)

# restart it
torchrun --rdzv_backend c10d --rdzv_id 123 --rdzv_endpoint localhost:29500 --nnodes 1:2 --nproc_per_node 1 --monitor_interval 1 test.py
```

Reviewed By: cbalioglu

Differential Revision: D32129005

fbshipit-source-id: 4e695d0b3397951d375ecee321add5faf0cfa3ea
gaocegege (Contributor) commented

Hi, will this patch be in 1.10.2?

#69100

drcege commented Apr 24, 2023

@kiukchung @gaocegege Hi, has this been truly fixed?

I have tested the following procedures:

(1) Start two copies of the agent on two nodes, say node.a and node.b

# node.a
$ torchrun --rdzv_backend c10d --rdzv_id 123 --rdzv_endpoint node.a:29500 --nnodes 1:2 --nproc_per_node 1 sleep.py

# node.b
$ torchrun --rdzv_backend c10d --rdzv_id 123 --rdzv_endpoint node.a:29500 --nnodes 1:2 --nproc_per_node 1 sleep.py

(2) If I ctrl+c to kill the agent on node.b, the process on node.a still survives. Then, when I restart the same command on node.b, it can join the group and the processes on both node.a and node.b get restarted to scale up.

(3) However, if I kill the agent on node.a, node.b terminates too. The stack trace:

WARNING:torch.distributed.elastic.rendezvous.dynamic_rendezvous:The node 'docker.b_3427_0' has failed to send a keep-alive heartbeat to the rendezvous 'test' due to an error of type RendezvousConnectionError.
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 3494 closing signal SIGTERM
WARNING:torch.distributed.elastic.rendezvous.dynamic_rendezvous:The node 'docker.b_3427_0' has failed to shutdown the rendezvous 'test' due to an error of type RendezvousConnectionError.
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py", line 113, in _call_store
    return getattr(self._store, store_op)(*args, **kwargs)
RuntimeError: Broken pipe

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch==1.14.0a0+44dac51', 'console_scripts', 'torchrun')())
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/run.py", line 762, in main
    run(args)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/run.py", line 753, in run
    elastic_launch(
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/launcher/api.py", line 237, in launch_agent
    result = agent.run()
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/elastic/metrics/api.py", line 129, in wrapper
    result = f(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/elastic/agent/server/api.py", line 709, in run
    result = self._invoke_run(role)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/elastic/agent/server/api.py", line 881, in _invoke_run
    num_nodes_waiting = rdzv_handler.num_nodes_waiting()
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py", line 1080, in num_nodes_waiting
    self._state_holder.sync()
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py", line 409, in sync
    get_response = self._backend.get_state()
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py", line 73, in get_state
    base64_state: bytes = self._call_store("get", self._key)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py", line 115, in _call_store
    raise RendezvousConnectionError(
torch.distributed.elastic.rendezvous.api.RendezvousConnectionError: The connection to the C10d store has failed. See inner exception for details.

So, my questions are:

Q1: Is step 3 the expected behavior? I thought it didn't matter which one to kill, but clearly the master node is an exception.

Q2: I also wonder if step 2 is the expected behavior? When node.b gets killed, node.a has no perception of it and no response. Shouldn't it restart to scale down to the minimum of 1 node? (since --nnodes 1:2)

kiukchung (Collaborator, Author) commented Apr 24, 2023

Yes, the issue has been fixed. What you are observing is a consequence of how the c10d rdzv backend works (it is not fault tolerant to the loss of the host node).

The c10d rdzv backend works by creating a TCPStore on the rdzv_endpoint node + port (in your case node.a:29500). This store contains the state of the job, so if node.a is lost then everything is toast. But it has the advantage that you don't need extra setup (things work out of the box with torch).

In practice you can do either:

  1. Run the node you specified in rdzv_endpoint on a more stable node. (e.g. if you're running on spot instances, run node.a on an on-demand instance)
  2. Run with --rdzv_backend=etcd and specify the --rdzv_endpoint=$ETCD_ENDPOINT:$PORT and run etcd on multiple nodes with redundancy.

drcege commented Apr 25, 2023

@kiukchung Thanks for your explanation!

I also noticed from discussions under other issues that the host node is not fault-tolerant. But this is not clearly pointed out in the official docs.

I tried the etcd-v2 backend and made it work on a single node (still exploring how to achieve redundancy across multiple nodes). There seems to be very little guidance in the docs, too. For example, the code threw an exception complaining about the lack of etcd. I googled and fixed it by pip install python-etcd. Is that the right package?

Finally, is there any hook or entry point where I can inject a custom function to save checkpoints before the workers restart? The recommended way of periodically saving checkpoints can lose a lot of training effort, especially for training large GPT models.

shinytang6 commented

[Quoting drcege's repro steps, stack trace, and questions Q1/Q2 from the comment above.]

@drcege I encountered the same problem and would like to ask if there is any conclusion about Q2. Thx

kiukchung (Collaborator, Author) commented

Hi @shinytang6, thanks for the question.

To summarize your repro:

  1. Start 1 process on 2 nodes (node.a and node.b)
  2. Set the --rdzv_backend as c10d
  3. Set the --rdzv_endpoint to node.a:29500
  4. Restarting the process on node.b works
  5. Trying to restart the process on node.a does not work

This is expected in rdzv_backend=c10d since in this rendezvous backend, the node hosting the --rdzv_endpoint is the single-point-of-failure.

The way the c10d rendezvous backend works is by hosting a c10d store (specifically a TCPStore) on the agent process running on the host specified in --rdzv_endpoint (in this case node.a). Therefore, if the agent on node.a dies, it takes the TCPStore (aka the rendezvous endpoint) with it.
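As a rough illustration of that single point of failure (a sketch using the public torch.distributed.TCPStore API, not what torchrun does internally; hostnames and keys are placeholders from the repro above):

```python
from datetime import timedelta

from torch.distributed import TCPStore

# On node.a (the --rdzv_endpoint host): the agent process hosts the store.
server = TCPStore("node.a", 29500, is_master=True, timeout=timedelta(seconds=30))
server.set("rdzv_state", b"...rendezvous state lives only here...")

# On node.b: the agent is merely a client of node.a's store.
client = TCPStore("node.a", 29500, is_master=False, timeout=timedelta(seconds=30))
print(client.get("rdzv_state"))

# If the process holding `server` (the agent on node.a) dies, the state dies
# with it, and client operations start failing with connection errors -- which
# is what surfaces above as "Broken pipe" / RendezvousConnectionError.
```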

If you need to be resilient to failures of all the participating nodes, you'll want to use a rendezvous backend that uses an external rendezvous endpoint. Take a look at etcd-v2 or etcd rdzv_backends here: https://pytorch.org/docs/stable/elastic/run.html#note-on-rendezvous-backend

shinytang6 commented Feb 13, 2025

[Quoting kiukchung's reply above about the c10d rendezvous backend and its single point of failure.]

@kiukchung Thank you for your reply! But I think there may be some misunderstanding (I understand that the master node cannot be killed or restarted).
My repro steps are:

  1. Start 1 process on 2 nodes (node.a and node.b)
  2. Set the --rdzv_backend as c10d
  3. Set the --rdzv_endpoint to node.a:29500
  4. Kill the process on node.b
  5. What I expected is that node.a would restart because of --nnodes 1:2, but actually node.a has no response

In addition, I also opened an issue (#147064) recording some experiments under PyTorch 2.4.1. I encountered problems similar to this issue (SIGINT/SIGTERM/SIGKILL-related things) and some other questions. If you have any ideas, please take a look.

kiukchung (Collaborator, Author) commented Feb 13, 2025

Hi @shinytang6

Apologies, I missed your second question

Q2: I also wonder if step 2 is the expected behavior? When node.b gets killed, node.a has no perception of it and no response. Shouldn't it restart to scale down to the minimum of 1 node? (since --nnodes 1:2)

Yes, you are correct that when node.b is killed (but not restarted) node.a should:

  1. terminate its current workers (since the world_size has changed)
  2. restart its current workers (new world size = 1x1) and resume training

Are you observing that the agent process on node.a is waiting forever for node.b to re-join?

I can think of a few reasons why node.a won't restart on its own when node.b is terminated.

  1. Not giving node.a enough time to restart. EDIT: default timeouts have been reduced so node.a should restart within 30sec or so.

    • How long have you given node.a to restart on its own? By default it can take some time for node.a to restart on its own.
    • It takes node.a (the master) up to --monitor-interval seconds (default: 0.1, previously 30) to detect a change in the number of participating nodes.
    • Additionally, when there are fewer than max_nodes waiting to join the job, node.a will wait another --rdzv_config=last_call=SEC seconds (default 30) to give any currently restarting nodes time to join (this prevents restart thrashing).
    • Finally, there is a --rdzv_config=timeout=SEC timeout (default 30 seconds) for the nodes to establish a TCP connection with the c10d store running on node.a.
  2. Not having specified --max-restarts.

    • In torchrun, --max-restarts defaults to 0, so on a failure all agents will shut down with a non-zero exit status (thus failing the job).
    • Given your observation that node.a is just waiting indefinitely, you must've set --max-restarts > 0.
  3. How are you initializing the process group in your training script?

Could you provide the logs for node.a for the repro in Q2?

kiukchung (Collaborator, Author) commented

@shinytang6 UPDATE: I was able to repro this issue on a single node as well (running two agents on the same node). Let's continue the conversation in #147064.

shinytang6 commented

[Quoting kiukchung's reply above about restart timeouts, --max-restarts, and process group initialization.]

Thank you for your detailed and patient explanation; it's really helpful! I tried waiting longer after node.b was killed, and after approximately ten minutes, node.a resumed. It seems related to how I initialized the process group.

dist.init_process_group("nccl")


In my reproduction process mentioned above, the default timeout was ten minutes (and --max-restarts is not 0). After shortening the timeout, node.a was able to detect node.b's departure quickly and resume training.

dist.init_process_group("nccl", timeout=timedelta(seconds=10))

But the issues of RendezvousClosedError and --max-restarts not taking effect still exist (described in #147064). Additionally, I found that if the timeout inside dist.init_process_group is long enough, the RendezvousClosedError issue is more easily reproduced.
