8000 [CD] Fix slim-wheel cuda_nvrtc import problem (#145582) · pytorch/pytorch@9752c7c · GitHub
[go: up one dir, main page]

Skip to content

Commit 9752c7c

Browse files
atalmanmalfet
authored andcommitted
[CD] Fix slim-wheel cuda_nvrtc import problem (#145582)
Similar fix as: #144816 Fixes: #145580 Found during testing of #138340 Please note both nvrtc and nvjitlink exist for cuda 11.8, 12.4 and 12.6 hence we can safely remove if statement. Preloading can apply to all supporting cuda versions. CUDA 11.8 path: ``` (.venv) root@b4ffe5c8ac8c:/pytorch/.ci/pytorch/smoke_test# ls /.venv/lib/python3.12/site-packages/torch/lib/../../nvidia/cuda_nvrtc/lib __init__.py __pycache__ libnvrtc-builtins.so.11.8 libnvrtc-builtins.so.12.4 libnvrtc.so.11.2 libnvrtc.so.12 (.venv) root@b4ffe5c8ac8c:/pytorch/.ci/pytorch/smoke_test# ls /.venv/lib/python3.12/site-packages/torch/lib/../../nvidia/nvjitlink/lib __init__.py __pycache__ libnvJitLink.so.12 ``` Test with rc 2.6 and CUDA 11.8: ``` python cudnn_test.py 2.6.0+cu118 ---------------------------------------------SDPA-Flash--------------------------------------------- ALL GOOD ---------------------------------------------SDPA-CuDNN--------------------------------------------- ALL GOOD ``` Thank you @nWEIdia for discovering this issue Pull Request resolved: #145582 Approved by: https://github.com/nWEIdia, https://github.com/eqy, https://github.com/kit1980, https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
1 parent 732c499 commit 9752c7c

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

torch/__init__.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -312,21 +312,21 @@ def _load_global_deps() -> None:
312312

313313
try:
314314
ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL)
315-
# Workaround slim-wheel CUDA-12.4+ dependency bug in libcusparse by preloading nvjitlink
316-
# In those versions of cuda cusparse depends on nvjitlink, but does not have rpath when
315+
# Workaround slim-wheel CUDA dependency bugs in cusparse and cudnn by preloading nvjitlink
316+
# and nvrtc. In CUDA-12.4+ cusparse depends on nvjitlink, but does not have rpath when
317317
# shipped as wheel, which results in OS picking wrong/older version of nvjitlink library
318-
# if `LD_LIBRARY_PATH` is defined
319-
# See https://github.com/pytorch/pytorch/issues/138460
320-
if version.cuda not in ["12.4", "12.6"]: # type: ignore[name-defined]
321-
return
318+
# if `LD_LIBRARY_PATH` is defined, see https://github.com/pytorch/pytorch/issues/138460
319+
# Similar issue exist in cudnn that dynamically loads nvrtc, unaware of its relative path.
320+
# See https://github.com/pytorch/pytorch/issues/145580
322321
try:
323322
with open("/proc/self/maps") as f:
324323
_maps = f.read()
325324
# libtorch_global_deps.so always depends in cudart, check if its installed via wheel
326325
if "nvidia/cuda_runtime/lib/libcudart.so" not in _maps:
327326
return
328-
# If all abovementioned conditions are met, preload nvjitlink
327+
# If all abovementioned conditions are met, preload nvjitlink and nvrtc
329328
_preload_cuda_deps("nvjitlink", "libnvJitLink.so.*[0-9]")
329+
_preload_cuda_deps("cuda_nvrtc", "libnvrtc.so.*[0-9]")
330330
except Exception:
331331
pass
332332

0 commit comments

Comments
 (0)
0