[CD] Fix slim-wheel nvjit-link import problem (#141063) · pytorch/pytorch@f297571 · GitHub

Commit f297571

malfet and kit1980
authored and committed
[CD] Fix slim-wheel nvjit-link import problem (#141063)
When another toolkit (say CUDA-12.3) is installed and `LD_LIBRARY_PATH` points to it, `import torch` will fail with

```
ImportError: /usr/local/lib/python3.10/dist-packages/torch/lib/../../nvidia/cusparse/lib/libcusparse.so.12: undefined symbol: __nvJitLinkComplete_12_4, version libnvJitLink.so.12
```

This cannot be worked around by tweaking rpath, because the failure also depends on the library load order, which is not guaranteed by any linker. Instead, solve it by preloading `nvjitlink` right after the global deps are loaded, by running something along the lines of the following

```python
if version.cuda in ["12.4", "12.6"]:
    with open("/proc/self/maps") as f:
        _maps = f.read()
    # libtorch_global_deps.so always depends in cudart, check if its installed via wheel
    if "nvidia/cuda_runtime/lib/libcudart.so" in _maps:
        # If all abovementioned conditions are met, preload nvjitlink
        _preload_cuda_deps("nvjitlink", "libnvJitLink.so.*[0-9]")
```

Fixes #140797

Pull Request resolved: #141063
Approved by: https://github.com/kit1980

Co-authored-by: Sergii Dymchenko <sdym@meta.com>
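To check whether a process is affected, it can help to see which copy of `libnvJitLink` actually got mapped. The following is a minimal diagnostic sketch, not part of this commit: the helper names `loaded_library_paths` and `current_process_paths` are hypothetical, and the approach assumes Linux, where `/proc/self/maps` lists one mapping per line with the backing file path as the sixth whitespace-separated field.

```python
from typing import List


def loaded_library_paths(maps_text: str, lib_prefix: str) -> List[str]:
    """Return the unique file paths in /proc/<pid>/maps-style text whose
    file name starts with lib_prefix, in order of first appearance."""
    paths: List[str] = []
    for line in maps_text.splitlines():
        # Fields: address, perms, offset, dev, inode, [pathname]
        fields = line.split(None, 5)
        if len(fields) < 6:
            continue  # anonymous mapping, no backing file
        path = fields[5].strip()
        file_name = path.rsplit("/", 1)[-1]
        if file_name.startswith(lib_prefix) and path not in paths:
            paths.append(path)
    return paths


def current_process_paths(lib_prefix: str) -> List[str]:
    """Linux-only convenience wrapper that reads this process's own maps."""
    with open("/proc/self/maps") as f:
        return loaded_library_paths(f.read(), lib_prefix)
```

Calling `current_process_paths("libnvJitLink.so")` after `import torch` would then show whether the mapped library comes from the `nvidia/nvjitlink` wheel directory or from a system toolkit directory on `LD_LIBRARY_PATH`.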
1 parent 5c727d5 commit f297571

1 file changed, +18 -0 lines changed

torch/__init__.py

Lines changed: 18 additions & 0 deletions
@@ -316,6 +316,24 @@ def _load_global_deps() -> None:
 
     try:
         ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL)
+        # Workaround slim-wheel CUDA-12.4+ dependency bug in libcusparse by preloading nvjitlink
+        # In those versions of cuda cusparse depends on nvjitlink, but does not have rpath when
+        # shipped as wheel, which results in OS picking wrong/older version of nvjitlink library
+        # if `LD_LIBRARY_PATH` is defined
+        # See https://github.com/pytorch/pytorch/issues/138460
+        if version.cuda not in ["12.4", "12.6"]: # type: ignore[name-defined]
+            return
+        try:
+            with open("/proc/self/maps") as f:
+                _maps = f.read()
+            # libtorch_global_deps.so always depends in cudart, check if its installed via wheel
+            if "nvidia/cuda_runtime/lib/libcudart.so" not in _maps:
+                return
+            # If all abovementioned conditions are met, preload nvjitlink
+            _preload_cuda_deps("nvjitlink", "libnvJitLink.so.*[0-9]")
+        except Exception:
+            pass
+
     except OSError as err:
         # Can only happen for wheel with cuda libs as PYPI deps
         # As PyTorch is not purelib, but nvidia-*-cu12 is
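The diff calls `_preload_cuda_deps("nvjitlink", "libnvJitLink.so.*[0-9]")` without showing that helper. As a rough illustration of the preloading idea only, the sketch below searches each import path for `nvidia/<folder>/lib/<glob>` and loads the first match with `ctypes.CDLL`. The names `find_wheel_lib` and `preload_wheel_lib` are hypothetical, and this is an assumption about the general technique, not PyTorch's actual implementation. The reasoning is that once a library with a given soname is already loaded, the dynamic loader typically reuses it when `libcusparse` later asks for the same soname, instead of searching `LD_LIBRARY_PATH`.

```python
import ctypes
import glob
import os
import sys
from typing import Iterable, Optional


def find_wheel_lib(
    lib_folder: str, lib_glob: str, search_paths: Optional[Iterable[str]] = None
) -> Optional[str]:
    """Return the first file matching <base>/nvidia/<lib_folder>/lib/<lib_glob>
    across search_paths (defaults to sys.path), or None if nothing matches."""
    bases = sys.path if search_paths is None else search_paths
    for base in bases:
        pattern = os.path.join(base, "nvidia", lib_folder, "lib", lib_glob)
        candidates = sorted(glob.glob(pattern))
        if candidates:
            return candidates[0]
    return None


def preload_wheel_lib(
    lib_folder: str, lib_glob: str, search_paths: Optional[Iterable[str]] = None
) -> bool:
    """Load the wheel-provided library into the process if it can be found.
    Returns True when a library was loaded, False when none was found."""
    path = find_wheel_lib(lib_folder, lib_glob, search_paths)
    if path is None:
        return False
    # Loading by absolute path bypasses the LD_LIBRARY_PATH search for this file.
    ctypes.CDLL(path)
    return True
```

Under these assumptions, `preload_wheel_lib("nvjitlink", "libnvJitLink.so.*[0-9]")` would pull in the wheel's copy before any library that depends on it is resolved. Like the committed code, a caller would likely want to wrap it in a broad `try`/`except` so that a failed preload never blocks `import torch`.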

0 commit comments
