[c10d][nccl][cuda] Regression (unspecified CUDA launch error) with test_c10d_nccl · Issue #136390 · pytorch/pytorch

nWEIdia opened this issue Sep 20, 2024 · 12 comments
Labels: module: nccl (Problems related to nccl support), oncall: distributed (Add this issue/PR to distributed oncall triage queue)
nWEIdia (Collaborator) commented Sep 20, 2024

🐛 Describe the bug

When running
python test/distributed/test_c10d_nccl.py -k test_nan_assert_float16
on a 2x H100 platform, the current nightly (and likely the v2.5.0 RC) produces the following CUDA error:

[screenshot of the CUDA error output, attached in the original issue]

The test did not check the return code, because:
[screenshot attached in the original issue]

Tested with ghcr.io/pytorch/pytorch-nightly:2.5.0.dev20240818-cuda12.4-cudnn9-devel (an earlier nightly), the test did not generate any errors other than failing the assertion check.

Bisected to #134300 (cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @atalman @malfet )
i.e. this commit 3645634

To reproduce on a platform with 2 GPUs:

1. docker pull ghcr.io/pytorch/pytorch-nightly:2.6.0.dev20240918-cuda12.4-cudnn9-devel
2. Clone pytorch and check out the commit above (3645634).
3. Run: python test/distributed/test_c10d_nccl.py -k test_nan_assert_float16

Versions

Bisected to #134300 (cc @kwen2501 @atalman @malfet )
i.e. this commit 3645634

cc @eqy @Aidyn-A @ptrblck

eqy (Collaborator) commented Sep 21, 2024

Note that the forward fix for gpuGuard/DeviceGuard in #134357 doesn't seem to fix the issue.

The current working theory is that something goes wrong in the communicator abort/cleanup, as the following diff seems to make the problem go away:

diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 878bb7c8be..297ed05318 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -2655,8 +2655,23 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
   op_id_++;

   auto device = getDevice(input);
+  at::cuda::OptionalCUDAGuard gpuGuard(device);
   const auto key = getKeyFromDevice(device);
+
+  if (nanCheck) {
+    auto currentStream = at::cuda::getCurrentCUDAStream(device.index());
+
+    checkForNan(input, currentStream);
+  }
+
   auto ncclComm = getNCCLComm(key, device, opType);
+  // Used many times below, so we stash the unordered_map lookup
+  auto ncclStream = ncclStreams_.at(key);
+
+  // First let NCCL streams wait for input tensors allocation streams
+  syncStream(device, ncclEvents_[key], ncclStream);
+
+

   if (coalescing_state_ & CoalActive) {
     coalescing_state_ |= CoalColl;
@@ -2673,12 +2688,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
     }
   }

-  // Used many times below, so we stash the unordered_map lookup
-  auto ncclStream = ncclStreams_.at(key);
-
-  // First let NCCL streams wait for input tensors allocation streams
-  syncStream(device, ncclEvents_[key], ncclStream);
-
   std::vector<at::Tensor> inputs{input};
   std::vector<at::Tensor> outputs{output};

@@ -2697,12 +2706,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
     work->stashed_for_allocator_safety_->push_back(input);
   }

-  at::cuda::OptionalCUDAGuard gpuGuard(device);
-
-  if (nanCheck) {
-    checkForNan(input, ncclStream);
-  }
-
   // Start event should only be recorded before the ncclGroupStart()
   if (work->timingEnabled_) {
     work->ncclStartEvent_->record(ncclStream);

However, moving getNCCLComm before the NaN check causes the problem to reappear.

eqy added the oncall: distributed and module: nccl labels on Sep 21, 2024
kwen2501 (Contributor) commented Sep 21, 2024

Hi, the "unspecified launch failure" is expected, because in this test we inject NaN into the buffer to see whether the NaN checker fires.

The test was put in skip_return_code_checks because the main thread may not always be the first one to catch the CUDA error -- sometimes the watchdog catches it first. In that case, the with self.assertRaises(RuntimeError): block will fail. But I agree that the test should be improved, maybe by disabling the watchdog.

I don't have the feeling that the library code has an issue, though.

kwen2501 self-assigned this on Sep 21, 2024
eqy (Collaborator) commented Sep 21, 2024

@kwen2501 I agree the unspecified launch failure part is expected; the more concerning part is that we see a SIGABRT somewhere in shutdown/abort, or a double free (when running with CUDA_LAUNCH_BLOCKING).

kwen2501 (Contributor) commented

Hmm, let me give it a try. The watchdog may issue SIGABRT in some cases and settings (it may be the default setting too, because people wanted a reliable shutdown rather than a watchdog hang).

nWEIdia (Collaborator, Author) commented Sep 22, 2024

> Hi, the "unspecified launch failure" is expected, because in this test we inject NaN into the buffer to see whether the NaN checker fires.
>
> The test was put in skip_return_code_checks because the main thread may not always be the first one to catch the CUDA error -- sometimes the watchdog catches it first. In that case, the with self.assertRaises(RuntimeError): block will fail. But I agree that the test should be improved, maybe by disabling the watchdog.
>
> I don't have the feeling that the library code has an issue, though.

Re "may not always": in our testing the error is completely deterministic.

kwen2501 (Contributor) commented

I printed the C++ stack trace, and it seems the error is thrown from tensor destruction:

terminate called after throwing an instance of 'c10::Error'
  what():  CUDA error: unspecified launch failure
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at /home/kw2501/local/pytorch/c10/cuda/CUDAException.cpp:43 (most recent call first):
C++ CapturedTraceback:
#4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> (), c10::SetStackTraceFetcher(std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0
#5 c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0
#6 c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) from ??:0
#7 c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) from ??:0
#8 c10::cuda::CUDACachingAllocator::Native::DeviceCachingAllocator::insert_events(c10::cuda::CUDACachingAllocator::Native::(anonymous namespace)::Block*) from CUDACachingAllocator.cpp:0
#9 c10::cuda::CUDACachingAllocator::Native::DeviceCachingAllocator::free(c10::cuda::CUDACachingAllocator::Native::(anonymous namespace)::Block*) from CUDACachingAllocator.cpp:0
#10 c10::cuda::CUDACachingAllocator::Native::local_raw_delete(void*) from crtstuff.c:0
#11 c10::StorageImpl::~StorageImpl() from crtstuff.c:0
#12 c10::TensorImpl::~TensorImpl() [clone .localalias] from TensorImpl.cpp:0
#13 THPVariable_clear(THPVariable*) from python_variable.cpp:0
#14 THPVariable_subclass_dealloc(_object*) from ??:0

I guess the flow is as follows:

  • The NaN checker detects a NaN and triggers a __trap() from the device (see the sketch below).
  • The CUDA context is poisoned, so any subsequent CUDA call returns an error; for example, NCCL raises "unspecified launch failure".
  • The program starts to tear down, and when it deallocates a tensor, the CUDA caching allocator also hits the poisoned CUDA context; but instead of just throwing an error, it goes for a "harder" kill -- SIGABRT.
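
A minimal sketch of such a device-side check (this is not PyTorch's actual checkForNan kernel; the kernel name and loop shape are made up for illustration):

```cpp
#include <cuda_runtime.h>
#include <math.h>

// Minimal sketch of a device-side NaN check, not PyTorch's actual
// checkForNan kernel. Each thread scans a strided slice of the buffer and
// calls __trap() on the first NaN it finds; the trap aborts the kernel and
// poisons the CUDA context, so every later CUDA call (NCCL launches, event
// recording, allocator frees) returns "unspecified launch failure".
__global__ void check_for_nan_kernel(const float* data, size_t n) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += (size_t)gridDim.x * blockDim.x) {
    if (isnan(data[i])) {
      __trap();
    }
  }
}
```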

kwen2501 (Contributor) commented

I don't know whether torch offers an API for resetting the CUDA device. According to the CUDA doc:

> No more commands can be sent to this device until cudaDeviceReset() is called to reinitialize the device.

(The statement is in the assert section, but I guess it applies to __trap() as well.)

It seems that, if we'd like to exit this test gracefully, we'd need to change the behavior of the CUDACachingAllocator. I wonder if you have an alternative suggestion.

Or is the ask to check the process's return code and confirm it is SIGABRT(6)? We could do that instead of skipping the return code check. Please advise.
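
For reference, the reset the CUDA doc refers to is the plain runtime call sketched below. This is the raw CUDA API, not a PyTorch API, and resetting the device would also destroy PyTorch's allocations, streams, and communicators on that device, so it would not be a drop-in recovery path:

```cpp
#include <cuda_runtime.h>
#include <cstdio>

// Sketch only: the raw CUDA runtime reset mentioned in the CUDA docs.
// It tears down everything on the device (allocations, streams, contexts),
// so PyTorch's caching allocator and process group state would be invalid
// afterwards.
void reset_device_after_trap(int device_index) {
  cudaSetDevice(device_index);
  cudaError_t err = cudaDeviceReset();
  if (err != cudaSuccess) {
    std::fprintf(stderr, "cudaDeviceReset failed: %s\n",
                 cudaGetErrorString(err));
  }
}
```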

eqy (Collaborator) commented Sep 23, 2024

Is there a reason we need to poison the CUDA context in this case? __trap seems a bit aggressive compared to, e.g., CUDA_KERNEL_ASSERT, which is used in other places:

CUDA_KERNEL_ASSERT(-sizes[i] <= index && index < sizes[i] && "index out of bounds");

If that's an acceptable alternative, I can try it out and see if it helps.
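
As a rough sketch of what that swap might look like, using a plain device-side assert (roughly what CUDA_KERNEL_ASSERT boils down to when device asserts are enabled); the kernel below is illustrative only, not the actual change:

```cpp
#include <cuda_runtime.h>
#include <cassert>
#include <math.h>

// Rough sketch, not the actual patch: the same strided NaN scan as a
// __trap()-based check, but failing via a device-side assert instead.
// The reported error becomes "device-side assert triggered" rather than
// "unspecified launch failure".
__global__ void check_for_nan_assert_kernel(const float* data, size_t n) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += (size_t)gridDim.x * blockDim.x) {
    assert(!isnan(data[i]) && "NaN detected before collective");
  }
}
```

As the follow-up below shows, this variant still poisons the context and ends in a SIGABRT at teardown; only the reported error message changes.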

kwen2501 (Contributor) commented Sep 23, 2024

There isn't a strong preference between __trap and CUDA_KERNEL_ASSERT; we just want a way to stop the launch of the next CUDA kernel (e.g. NCCL ops). Ideally, though, the instruction should have good performance, because the NaN check runs on the fly.

kwen2501 (Contributor) commented

I tried CUDA_KERNEL_ASSERT and also got SIGABRT(6).
The stack is similar to that of __trap():

terminate called after throwing an instance of 'c10::Error'
  what():  CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at /home/kw2501/local/pytorch/c10/cuda/CUDAException.cpp:43 (most recent call first):
C++ CapturedTraceback:
#4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> (), c10::SetStackTraceFetcher(std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0
#5 c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0
#6 c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) from ??:0
#7 c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) from ??:0
#8 c10::cuda::CUDACachingAllocator::Native::DeviceCachingAllocator::insert_events(c10::cuda::CUDACachingAllocator::Native::(anonymous namespace)::Block*) from CUDACachingAllocator.cpp:0
#9 c10::cuda::CUDACachingAllocator::Native::DeviceCachingAllocator::free(c10::cuda::CUDACachingAllocator::Native::(anonymous namespace)::Block*) from CUDACachingAllocator.cpp:0
#10 c10::cuda::CUDACachingAllocator::Native::local_raw_delete(void*) from crtstuff.c:0
#11 c10::StorageImpl::~StorageImpl() from crtstuff.c:0
#12 c10::TensorImpl::~TensorImpl() [clone .localalias] from TensorImpl.cpp:0
#13 THPVariable_clear(THPVariable*) from python_variable.cpp:0
#14 THPVariable_subclass_dealloc(_object*) from ??:0
SIGABRT(6), PID: 4015157, Thread 4015157: 
 frame #0: c10::FatalSignalHandler::stacktraceSignalHandler(bool) + 0x86 (0x7fc7fdb9e8b6 in /home/kw2501/local/pytorch/torch/lib/libc10.so)
frame #1: c10::FatalSignalHandler::fatalSignalHandler(int) + 0x25a (0x7fc7fdb9ef5a in /home/kw2501/local/pytorch/torch/lib/libc10.so)
frame #2: <unknown function> + 0x3e6f0 (0x7fc7ff03e6f0 in /lib64/libc.so.6)
frame #3: <unknown function> + 0x8b99c (0x7fc7ff08b99c in /lib64/libc.so.6)
frame #4: raise + 0x16 (0x7fc7ff03e646 in /lib64/libc.so.6)
frame #5: abort + 0xd3 (0x7fc7ff0287f3 in /lib64/libc.so.6)
frame #6: <unknown function> + 0xa1b21 (0x7fc7e52a1b21 in /lib64/libstdc++.so.6)
frame #7: <unknown function> + 0xad52c (0x7fc7e52ad52c in /lib64/libstdc++.so.6)
frame #8: <unknown function> + 0xac4f9 (0x7fc7e52ac4f9 in /lib64/libstdc++.so.6)
frame #9: __gxx_personality_v0 + 0x9a (0x7fc7e52acc7a in /lib64/libstdc++.so.6)
frame #10: <unknown function> + 0x112d4 (0x7fc7fdf382d4 in /lib64/libgcc_s.so.1)
frame #11: _Unwind_Resume + 0x12e (0x7fc7fdf38d0e in /lib64/libgcc_s.so.1)
frame #12: <unknown function> + 0x10404 (0x7fc7fdf5d404 in /home/kw2501/local/pytorch/torch/lib/libc10_cuda.so)
frame #13: <unknown function> + 0x1db5c (0x7fc7fdf6ab5c in /home/kw2501/local/pytorch/torch/lib/libc10_cuda.so)
frame #14: <unknown function> + 0x556f80 (0x7fc7fcd56f80 in /home/kw2501/local/pytorch/torch/lib/libtorch_python.so)
frame #15: c10::TensorImpl::~TensorImpl() + 0x9 (0x7fc7fdb6a539 in /home/kw2501/local/pytorch/torch/lib/libc10.so)
frame #16: <unknown function> + 0x826b28 (0x7fc7fd026b28 in /home/kw2501/local/pytorch/torch/lib/libtorch_python.so)
frame #17: THPVariable_subclass_dealloc(_object*) + 0x2a6 (0x7fc7fd026e46 in /home/kw2501/local/pytorch/torch/lib/libtorch_python.so)

eqy (Collaborator) commented Sep 23, 2024

Opened #136486 in case we want to consider an alternative that incurs the overhead of a sync but surfaces a recoverable failure instead.
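
A minimal sketch of that style of check (host-side, paying for a synchronization but surfacing a catchable error and leaving the CUDA context intact); this is illustrative only, not the code in #136486, and the helper name is made up:

```cpp
#include <ATen/ATen.h>
#include <c10/util/Exception.h>

// Illustrative only, not the code in #136486: a host-side NaN check that
// synchronizes on a tiny reduction and throws a catchable c10::Error,
// leaving the CUDA context usable so teardown can proceed normally.
void check_for_nan_host_side(const at::Tensor& input) {
  // isnan().any() runs on the GPU; item<bool>() syncs the current stream
  // and copies the one-element result back to the host.
  const bool has_nan = input.isnan().any().item<bool>();
  TORCH_CHECK(!has_nan, "NaN detected in tensor before NCCL collective");
}
```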

nWEIdia (Collaborator, Author) commented May 13, 2025

See #153479.
Some setups running this unit test hit a 300-second timeout.
