[TorchElastic] Shutdown behavior appears incorrect and breaks rendezvous · Issue #123678 · pytorch/pytorch
@eqy

Description

🐛 Describe the bug

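Forwarding an internal observation: the expected flow of a rank leaving and later rejoining no longer seems to work after a60b566#diff-6630d974d12b21479b3bc34595b1496fc253d48d703e5089ec654b4d9dd24291R313. Previously, after a worker-group failure the agent restarted the workers and re-rendezvoused; now the re-rendezvous fails with RendezvousClosedError.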

Previously:

Exception raised from ncclCommWatchdog at /opt/pytorch/pytorch/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1339 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x99 (0x7f94f599bdc9 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xf6f04e (0x7f94947e604e in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0xca016a (0x7f949451716a in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: <unknown function> + 0xdc253 (0x7f94f54b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #4: <unknown function> + 0x94ac3 (0x7f95049e4ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #5: <unknown function> + 0x126850 (0x7f9504a76850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

[2024-04-04 11:42:17,400] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 1433 closing signal SIGTERM
[2024-04-04 11:42:17,405] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 1434 closing signal SIGTERM
[2024-04-04 11:42:17,414] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 1440 closing signal SIGTERM
[2024-04-04 11:42:17,784] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: -6) local_rank: 2 (pid: 1435) of binary: /usr/bin/python
[2024-04-04 11:42:17,784] torch.distributed.elastic.agent.server.api: [INFO] [default] Worker group FAILED. 100/100 attempts left; will restart worker group
[2024-04-04 11:42:17,784] torch.distributed.elastic.agent.server.api: [INFO] [default] Stopping worker group
[2024-04-04 11:42:17,784] torch.distributed.elastic.agent.server.api: [INFO] [default] Rendezvous'ing worker group

Now:

[rank0]:[E ProcessGroupNCCL.cpp:588] [Rank 0] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[rank0]:[E ProcessGroupNCCL.cpp:594] [Rank 0] To avoid data inconsistency, we are taking the entire process down.
[rank0]:[E ProcessGroupNCCL.cpp:1385] [PG 0 Rank 0] NCCL watchdog thread terminated with exception: NCCL error: remote process exited or there was a network error, NCCL version 2.20.5
ncclRemoteError: A call failed possibly due to a network error or a remote process exiting prematurely.
Last error:
socketProgress: Connection closed by remote peer 10.244.3.40<34366>
Exception raised from checkForNCCLErrorsInternal at /opt/pytorch/pytorch/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1736 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x99 (0x7fe3c6483d89 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::checkForNCCLErrorsInternal(std::vector<std::shared_ptr<c10d::NCCLComm>, std::allocator<std::shared_ptr<c10d::NCCLComm> > > const&) + 0x35d (0x7fe363d378dd in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::WorkNCCL::checkAndSetException() + 0x7b (0x7fe363d37b8b in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::watchdogHandler() + 0x1d7 (0x7fe363d3f7c7 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10f (0x7fe363d408bf in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: <unknown function> + 0xdc253 (0x7fe3c60b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #6: <unknown function> + 0x94ac3 (0x7fe3d5572ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #7: <unknown function> + 0x126850 (0x7fe3d5604850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

terminate called after throwing an instance of 'c10::DistBackendError'
  what():  [PG 0 Rank 0] NCCL watchdog thread terminated with exception: NCCL error: remote process exited or there was a network error, NCCL version 2.20.5
ncclRemoteError: A call failed possibly due to a network error or a remote process exiting prematurely.
Last error:
socketProgress: Connection closed by remote peer 10.244.3.40<34366>
Exception raised from checkForNCCLErrorsInternal at /opt/pytorch/pytorch/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1736 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x99 (0x7fe3c6483d89 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::checkForNCCLErrorsInternal(std::vector<std::shared_ptr<c10d::NCCLComm>, std::allocator<std::shared_ptr<c10d::NCCLComm> > > const&) + 0x35d (0x7fe363d378dd in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::WorkNCCL::checkAndSetException() + 0x7b (0x7fe363d37b8b in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::watchdogHandler() + 0x1d7 (0x7fe363d3f7c7 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10f (0x7fe363d408bf in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: <unknown function> + 0xdc253 (0x7fe3c60b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #6: <unknown function> + 0x94ac3 (0x7fe3d5572ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #7: <unknown function> + 0x126850 (0x7fe3d5604850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Exception raised from ncclCommWatchdog at /opt/pytorch/pytorch/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1389 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x99 (0x7fe3c6483d89 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xf4d5ae (0x7fe363d6c5ae in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0xc76c88 (0x7fe363a95c88 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: <unknown function> + 0xdc253 (0x7fe3c60b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #4: <unknown function> + 0x94ac3 (0x7fe3d5572ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #5: <unknown function> + 0x126850 (0x7fe3d5604850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

[2024-04-04 12:32:00,896] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 1450 closing signal SIGTERM
[2024-04-04 12:32:00,907] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 1451 closing signal SIGTERM
[2024-04-04 12:32:00,915] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 1452 closing signal SIGTERM
[2024-04-04 12:32:00,923] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 1453 closing signal SIGTERM
[2024-04-04 12:32:00,926] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 1454 closing signal SIGTERM
[2024-04-04 12:32:00,936] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 1455 closing signal SIGTERM
[2024-04-04 12:32:00,941] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 1456 closing signal SIGTERM
[2024-04-04 12:32:02,777] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: -6) local_rank: 0 (pid: 1449) of binary: /usr/bin/python
[2024-04-04 12:32:02,777] torch.distributed.elastic.agent.server.api: [INFO] [default] Worker group FAILED. 100/100 attempts left; will restart worker group
[2024-04-04 12:32:02,777] torch.distributed.elastic.agent.server.api: [INFO] [default] Stopping worker group
[2024-04-04 12:32:02,780] torch.distributed.elastic.agent.server.api: [INFO] [default] Rendezvous'ing worker group
Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 834, in main
    run(args)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 825, in run
    elastic_launch(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 137, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 262, in launch_agent
    result = agent.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/metrics/api.py", line 123, in wrapper
    result = f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/agent/server/api.py", line 736, in run
    result = self._invoke_run(role)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/agent/server/api.py", line 904, in _invoke_run
    self._restart_workers(self._worker_group)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/metrics/api.py", line 123, in wrapper
    result = f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/agent/server/api.py", line 727, in _restart_workers
    self._initialize_workers(worker_group)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/metrics/api.py", line 123, in wrapper
    result = f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/agent/server/api.py", line 708, in _initialize_workers
    self._rendezvous(worker_group)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/metrics/api.py", line 123, in wrapper
    result = f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/agent/server/api.py", line 551, in _rendezvous
    store, group_rank, group_world_size = spec.rdzv_handler.next_rendezvous()
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py", line 1116, in next_rendezvous
    self._op_executor.run(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py", line 658, in run
    raise RendezvousClosedError()
torch.distributed.elastic.rendezvous.api.RendezvousClosedError
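The traceback shows the restart path: after the worker group fails, _restart_workers calls _initialize_workers, which calls _rendezvous, and next_rendezvous raises RendezvousClosedError because the session is already closed. The toy model below sketches that sequence under an assumption (consistent with the comment that follows, but not verified against the commit itself) that stopping the workers now also shuts down the rendezvous. ToyRendezvous, ToyAgent, and the local RendezvousClosedError are hypothetical stand-ins, not the torch.distributed.elastic classes.

```python
# Toy model (NOT the torch.distributed.elastic API) of the failing sequence:
# _restart_workers -> stop workers -> re-rendezvous -> RendezvousClosedError.


class RendezvousClosedError(Exception):
    """Hypothetical stand-in for the torch RendezvousClosedError."""


class ToyRendezvous:
    def __init__(self) -> None:
        self.closed = False
        self.round = 0

    def shutdown(self) -> None:
        # Marks the shared session closed for every participant.
        self.closed = True

    def next_rendezvous(self) -> int:
        if self.closed:
            raise RendezvousClosedError()
        self.round += 1
        return self.round


class ToyAgent:
    def __init__(self, rdzv: ToyRendezvous) -> None:
        self.rdzv = rdzv

    def _stop_workers(self) -> None:
        # Assumed post-change behavior: stopping workers also closes the
        # rendezvous, even when the stop is only part of a restart.
        self.rdzv.shutdown()

    def _restart_workers(self) -> int:
        self._stop_workers()
        # Re-rendezvous after stopping; fails because the session is closed.
        return self.rdzv.next_rendezvous()


agent = ToyAgent(ToyRendezvous())
try:
    agent._restart_workers()
    outcome = "restarted"
except RendezvousClosedError:
    outcome = "RendezvousClosedError"
print(outcome)  # RendezvousClosedError
```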

Comment from @sathyanarays:
"The rendezvous session should be closed only when the workers are stopped and should not be closed on restarts. Removing the changes in file torch/distributed/elastic/agent/server/local_elastic_agent.py of this commit makes the rendezvous work as expected. This needs to be fixed up-stream."

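A minimal sketch of the behavior the comment asks for, assuming the agent can tell a restart apart from a final stop. The is_restart flag and the toy classes are hypothetical illustrations, not the actual local_elastic_agent.py code or the upstream fix.

```python
# Toy model (NOT the torch.distributed.elastic API): close the rendezvous
# only on a final stop, never on a restart.


class RendezvousClosedError(Exception):
    """Hypothetical stand-in for the torch RendezvousClosedError."""


class ToyRendezvous:
    def __init__(self) -> None:
        self.closed = False
        self.round = 0

    def shutdown(self) -> None:
        self.closed = True

    def next_rendezvous(self) -> int:
        if self.closed:
            raise RendezvousClosedError()
        self.round += 1
        return self.round


class ToyAgent:
    def __init__(self, rdzv: ToyRendezvous) -> None:
        self.rdzv = rdzv

    def _stop_workers(self, is_restart: bool = False) -> None:
        # (Terminating local worker processes would happen here.)
        if not is_restart:
            # Only a final stop closes the shared rendezvous session.
            self.rdzv.shutdown()

    def _restart_workers(self) -> int:
        self._stop_workers(is_restart=True)
        return self.rdzv.next_rendezvous()

    def shutdown(self) -> None:
        self._stop_workers(is_restart=False)


agent = ToyAgent(ToyRendezvous())
first = agent._restart_workers()   # session stays open across restarts
second = agent._restart_workers()
agent.shutdown()                   # final stop closes the session
try:
    agent.rdzv.next_rendezvous()
    after_shutdown = "joined"
except RendezvousClosedError:
    after_shutdown = "closed"
print(first, second, after_shutdown)  # 1 2 closed
```

Closing only on the final stop keeps the session available for the re-rendezvous that follows a restart, while still marking it closed once the agent is actually done.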
Versions

Behavior differs between builds before and after a60b566#diff-6630d974d12b21479b3bc34595b1496fc253d48d703e5089ec654b4d9dd24291R313

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @dzhulgakov

Labels

module: elastic (Related to torch.distributed.elastic)
oncall: distributed (Add this issue/PR to distributed oncall triage queue)