[c10d][nccl][cuda] Regression (unspecified cuda launch error) with test_c10d_nccl #136390
Comments
Note that the forward fix for […] The current working theory is that something is going wrong in the communicator abort/cleanup, as the following diff seems to make the problem go away: […] but moving […]
Hi, the "unspecified launch failure" is expected, because in this test we inject NaN into the buffer and check whether the NaN checker raises an alarm. The test was put in skip_return_code_checks because the main thread may not always be the first one to catch the CUDA error -- sometimes the watchdog catches it first. In that case, the […] I don't have a feeling that the library code has an issue, though.
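For readers unfamiliar with the test, here is a minimal, hedged sketch of a NaN-assert test of this shape. The real test lives in test/distributed/test_c10d_nccl.py; the TORCH_NCCL_NAN_CHECK env var name and the tensor shape here are assumptions on my part, not the exact test body.

```python
# Minimal sketch, NOT the actual test body. Assumes TORCH_NCCL_NAN_CHECK is the
# switch that enables ProcessGroupNCCL's NaN checker.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    os.environ["TORCH_NCCL_NAN_CHECK"] = "1"   # enable the NaN checker (assumed name)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    t = torch.ones(1024, dtype=torch.float16, device="cuda")
    t[0] = float("nan")                        # inject NaN into the communication buffer
    try:
        dist.all_reduce(t)                     # the NaN checker should fire here
        torch.cuda.synchronize()
    except RuntimeError as e:
        # Either the main thread or the NCCL watchdog may report the error first,
        # which is why the real test is placed in skip_return_code_checks.
        print(f"rank {rank} caught: {e}")
    # Cleanup after a poisoned CUDA context is exactly where this issue reports a crash.
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```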
@kwen2501 I agree the […]
Hmm, let me give it a try. The watchdog may issue […]
RE: "may not always": the error becomes very deterministic.
I printed the C++ stack trace, and it seems the error is thrown from tensor destruction: […]
I guess the flow is as follows: […]
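To illustrate the general mechanism (this is not the elided stack trace or the exact flow above): once a kernel has hit a device-side assert or an "unspecified launch failure", the CUDA context is poisoned and the same sticky error is re-reported by later, unrelated CUDA calls, including the frees issued when tensors are destroyed. A small standalone demonstration, assuming a single CUDA device:

```python
# Trigger a device-side assert, then show that the poisoned context makes
# subsequent, unrelated CUDA calls fail with the same sticky error.
import torch

a = torch.zeros(4, device="cuda")
bad_idx = torch.tensor([10], device="cuda")      # out of bounds -> device-side assert
try:
    b = a[bad_idx]
    torch.cuda.synchronize()                     # the async error is reported here
except RuntimeError as e:
    print("first error:", e)

try:
    torch.ones(1, device="cuda")                 # unrelated call, same sticky error
except RuntimeError as e:
    print("subsequent error:", e)
```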
I don't know if torch offers an API for resetting the CUDA device.
(The statement is in the […]) It seems that, if we'd like to exit this test gracefully, we'd need to change the behavior in […] Or, is the ask to check the process's return code and confirm it is […]?
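If we did want to reset the device, the only mechanism I'm aware of is the CUDA runtime call cudaDeviceReset(), which torch does not appear to wrap. A hedged sketch of calling it via ctypes follows; the shared-library names are guesses, and resetting invalidates every allocation and stream owned by the process, so it is really only safe immediately before exit.

```python
# Hedged sketch: call the CUDA runtime's cudaDeviceReset() directly, since torch
# does not appear to expose a public device-reset API. Library names are guesses.
import ctypes
import ctypes.util

def cuda_device_reset() -> int:
    candidates = ["libcudart.so", "libcudart.so.12", ctypes.util.find_library("cudart")]
    for name in candidates:
        if not name:
            continue
        try:
            libcudart = ctypes.CDLL(name)
        except OSError:
            continue
        # cudaError_t cudaDeviceReset(void); returns 0 (cudaSuccess) on success.
        return libcudart.cudaDeviceReset()
    raise OSError("could not locate the CUDA runtime library")
```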
Is there a reason we need to poison the CUDA context in this case? […]
If that's an acceptable alternative, I can try it out and see if it helps.
There isn't a strong preference between […]
Tried […]
Opened #136486 in case we want to consider an alternative, which incurs the overhead of a sync but instead surfaces a recoverable failure.
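For reference, a hedged sketch of what a sync-then-raise alternative could look like at the Python level. This is not the content of #136486, just an illustration of the tradeoff: a host-side check forces a synchronization, but the failure is an ordinary, catchable exception rather than a poisoned CUDA context.

```python
# Illustrative only: host-side NaN check before a collective. Slower (forces a
# device sync) but recoverable, unlike a device-side assert.
import torch
import torch.distributed as dist

def checked_all_reduce(tensor: torch.Tensor, group=None) -> None:
    # .item() synchronizes the device and brings the result to the host.
    if torch.isnan(tensor).any().item():
        raise RuntimeError("NaN detected in collective input")
    dist.all_reduce(tensor, group=group)
```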
See: #153479
🐛 Describe the bug
When running python test/distributed/test_c10d_nccl.py -k test_nan_assert_float16 on an H100x2 platform, the current nightly (and likely the v2.5.0 RC) produces the following CUDA error (an "unspecified launch failure"): […]
It did not check the return code, because: [screenshot elided]
Tested with ghcr.io/pytorch/pytorch-nightly:2.5.0.dev20240818-cuda12.4-cudnn9-devel, the test did not generate errors other than failing the assertion check.
Bisected to #134300 (cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @atalman @malfet )
i.e. this commit 3645634
To reproduce on a 2-GPU platform:
docker pull ghcr.io/pytorch/pytorch-nightly:2.6.0.dev20240918-cuda12.4-cudnn9-devel
clone pytorch and check out the above commit (3645634)
run:
python test/distributed/test_c10d_nccl.py -k test_nan_assert_float16
Versions
cc @eqy @Aidyn-A @ptrblck