CUDA deps cannot be preloaded under Bazel #117350

Open
georgevreilly opened this issue Jan 12, 2024 · 16 comments · May be fixed by #137059
Labels
module: bazel · module: build (Build system issues) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

georgevreilly commented Jan 12, 2024

🐛 Describe the bug

If Torch 2.1.0 is used as a dependency with Bazel and rules_python, _preload_cuda_deps fails with OSError: libcufft.so.11: cannot open shared object file: No such file or directory.

Torch 2.1 works fine if you install it and its CUDA dependencies into a single site-packages (e.g., in a virtualenv). It doesn't work with Bazel, because Bazel installs each dependency into its own directory tree, each of which is appended to PYTHONPATH.
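
For context, this is roughly the mechanism at play: dlopen resolves a library's DT_NEEDED entries against sonames that are already loaded in the process, so each CUDA library has to be preloaded after the libraries it depends on. A minimal sketch of the idea (illustrative only, not torch's actual code; the folder/pattern pairs come from the traceback below):

import ctypes
import glob
import os
import sys

def preload_nvidia_lib(lib_folder: str, lib_pattern: str) -> None:
    # Find nvidia/<lib_folder>/lib/<pattern> in any sys.path entry and dlopen it.
    # Once loaded, later dlopen calls can satisfy their DT_NEEDED entries against
    # it by soname, even though the file lives in a separate directory tree.
    for entry in sys.path:
        hits = glob.glob(os.path.join(entry, "nvidia", lib_folder, "lib", lib_pattern))
        if hits:
            ctypes.CDLL(hits[0])
            return
    raise OSError(f"{lib_pattern} not found on sys.path")

# Order matters: libcusolver.so needs libcusparse.so, which needs libnvJitLink.so,
# so the dependencies must be loaded before the dependents.
for folder, pattern in [
    ("nvjitlink", "libnvJitLink.so.*[0-9]"),
    ("cusparse", "libcusparse.so.*[0-9]"),
    ("cusolver", "libcusolver.so.*[0-9]"),
]:
    preload_nvidia_lib(folder, pattern)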

$ bazel test //...
Starting local Bazel server and connecting to it...
INFO: Analyzed 3 targets (66 packages loaded, 15423 targets configured).
INFO: Found 2 targets and 1 test target...
FAIL: //calculator:calc_test (see /pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/execroot/__main__/bazel-out/k8-fastbuild/testlogs/calculator/calc_test/test.log)
INFO: From Testing //calculator:calc_test:
==================== Test output for //calculator:calc_test:
sys.path=['/pay/src/torch21/calculator',
    '/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/sandbox/linux-sandbox/1/execroot/__main__/bazel-out/k8-fastbuild/bin/calculator/calc_test.runfiles',
    '/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/sandbox/linux-sandbox/1/execroot/__main__/bazel-out/k8-fastbuild/bin/calculator/calc_test.runfiles/python_deps_filelock/site-packages',
    '/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/sandbox/linux-sandbox/1/execroot/__main__/bazel-out/k8-fastbuild/bin/calculator/calc_test.runfiles/python_deps_fsspec/site-packages',
... [40 directories omitted] ...
    '/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/sandbox/linux-sandbox/1/execroot/__main__/bazel-out/k8-fastbuild/bin/calculator/calc_test.runfiles/python_deps_sympy',
    '/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/sandbox/linux-sandbox/1/execroot/__main__/bazel-out/k8-fastbuild/bin/calculator/calc_test.runfiles/python_deps_torch',
    '/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/sandbox/linux-sandbox/1/execroot/__main__/bazel-out/k8-fastbuild/bin/calculator/calc_test.runfiles/python_deps_triton',
    '/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/sandbox/linux-sandbox/1/execroot/__main__/bazel-out/k8-fastbuild/bin/calculator/calc_test.runfiles/python_deps_typing_extensions',
    '/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/execroot/__main__/external/python_3_8_x86_64-unknown-linux-gnu/lib/python38.zip',
    '/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/execroot/__main__/external/python_3_8_x86_64-unknown-linux-gnu/lib/python3.8',
    '/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/execroot/__main__/external/python_3_8_x86_64-unknown-linux-gnu/lib/python3.8/lib-dynload',
    '/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/execroot/__main__/external/python_3_8_x86_64-unknown-linux-gnu/lib/python3.8/site-packages']
Traceback (most recent call last):
  File "/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/sandbox/linux-sandbox/1/execroot/__main__/bazel-out/k8-fastbuild/bin/calculator/calc_test.runfiles/python_deps_torch/site-packages/torch/__init__.py", line 174, in _load_global_deps
    ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
  File "/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/execroot/__main__/external/python_3_8_x86_64-unknown-linux-gnu/lib/python3.8/ctypes/__init__.py", line 373, in __init__
    self._handle = _dlopen(self._name, mode)
OSError: libcufft.so.11: cannot open shared object file: No such file or directory

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/sandbox/linux-sandbox/1/execroot/__main__/bazel-out/k8-fastbuild/bin/calculator/calc_test.runfiles/__main__/calculator/calc_test.py", line 10, in <module>
    import torch  # type: ignore
  File "/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/sandbox/linux-sandbox/1/execroot/__main__/bazel-out/k8-fastbuild/bin/calculator/calc_test.runfiles/python_deps_torch/site-packages/torch/__init__.py", line 234, in <module>
    _load_global_deps()
  File "/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/sandbox/linux-sandbox/1/execroot/__main__/bazel-out/k8-fastbuild/bin/calculator/calc_test.runfiles/python_deps_torch/site-packages/torch/__init__.py", line 195, in _load_global_deps
    _preload_cuda_deps(lib_folder, lib_name)
  File "/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/sandbox/linux-sandbox/1/execroot/__main__/bazel-out/k8-fastbuild/bin/calculator/calc_test.runfiles/python_deps_torch/site-packages/torch/__init__.py", line 161, in _preload_cuda_deps
    ctypes.CDLL(lib_path)
  File "/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/execroot/__main__/external/python_3_8_x86_64-unknown-linux-gnu/lib/python3.8/ctypes/__init__.py", line 373, in __init__
    self._handle = _dlopen(self._name, mode)
OSError: libnvJitLink.so.12: cannot open shared object file: No such file or directory
================================================================================
INFO: Elapsed time: 131.313s, Critical Path: 6.78s
INFO: 5 processes: 3 internal, 2 linux-sandbox.
INFO: Build completed, 1 test FAILED, 5 total actions
//calculator:calc_test                                                   FAILED in 0.9s
  /pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/execroot/__main__/bazel-out/k8-fastbuild/testlogs/calculator/calc_test/test.log

Executed 1 out of 1 test: 1 fails locally.

This can be fixed by slightly reordering cuda_libs in _load_global_deps so that they are topologically sorted.

diff --git a/torch/__init__.py b/torch/__init__.py
index 98c9a43511c..bad6a5f6c3d 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -178,7 +178,11 @@ def _load_global_deps() -> None:
     except OSError as err:
         # Can only happen for wheel with cuda libs as PYPI deps
         # As PyTorch is not purelib, but nvidia-*-cu12 is
+        # These dependencies have been topologically sorted,
+        # so that a lib is loaded after all of its dependencies.
         cuda_libs: Dict[str, str] = {
+            'nvjitlink': 'libnvJitLink.so.*[0-9]',
+            'cusparse': 'libcusparse.so.*[0-9]',
             'cublas': 'libcublas.so.*[0-9]',
             'cudnn': 'libcudnn.so.*[0-9]',
             'cuda_nvrtc': 'libnvrtc.so.*[0-9]',
@@ -187,7 +191,6 @@ def _load_global_deps() -> None:
             'cufft': 'libcufft.so.*[0-9]',
             'curand': 'libcurand.so.*[0-9]',
             'cusolver': 'libcusolver.so.*[0-9]',
-            'cusparse': 'libcusparse.so.*[0-9]',
             'nccl': 'libnccl.so.*[0-9]',
             'nvtx': 'libnvToolsExt.so.*[0-9]',
         }

I have a full repro of the problem, which has a tiny Python app that works in a regular virtualenv, but fails with Bazel. I also created a tool there that patches the Torch wheel. The patched wheel works in Bazel for us.

Versions

$ python collect_env.py
/pay/tmp/venv-torch21/lib/python3.8/site-packages/torch/nn/modules/transformer.py:20: UserWarning: Failed to initialize NumPy: numpy.core.multiarray failed to import (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:84.)
  device: torch.device = torch.device(torch._C._get_default_device()),  # torch.device('cpu'),
Collecting environment information...
/pay/tmp/venv-torch21/lib/python3.8/site-packages/torch/cuda/__init__.py:138: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 11060). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:108.)
  return torch._C._cuda_getDeviceCount() > 0

cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      46 bits physical, 48 bits virtual
CPU(s):                             8
On-line CPU(s) list:                0-7
Thread(s) per core:                 2
Core(s) per socket:                 4
Socket(s):                          1
NUMA node(s):                       1
Vendor ID:                          GenuineIntel
CPU family:                         6
Model:                              79
Model name:                         Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz
Stepping:                           1
CPU MHz:                            3000.000
CPU max MHz:                        3000.0000
CPU min MHz:                        1200.0000
BogoMIPS:                           4600.02
Hypervisor vendor:                  Xen
Virtualization type:                full
L1d cache:                          128 KiB
L1i cache:                          128 KiB
L2 cache:                           1 MiB
L3 cache:                           45 MiB
NUMA node0 CPU(s):                  0-7
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        KVM: Mitigation: VMX unsupported
Vulnerability L1tf:                 Mitigation; PTE Inversion
Vulnerability Mds:                  Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown:             Mitigation; PTI
Vulnerability Mmio stale data:      Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Vulnerable
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single pti fsgsbase bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx xsaveopt

Versions of relevant libraries:
[pip3] torch==2.1.0
[pip3] triton==2.1.0
[conda] Could not collect

cc @malfet @seemethere

malfet added the triaged, module: bazel, and module: build labels Jan 13, 2024
@cullenwren-volair

Are there any other workarounds for this besides patching the wheel as you suggested, @georgevreilly? I'm wondering if there is a way to get Bazel/pip to place torch, as well as all the upstream nvidia pip packages, into a single site-packages folder.

Also, FWIW, with python3.10 and torch==2.1.0+cu121 I don't experience this issue; however, I am now experiencing it with torch==2.2.0+cu121.

araffin commented Feb 22, 2024

For reference, the issue seems to come from nvidia-cusolver, whose -cu12 variant has additional dependencies:

PyTorch 2.0, which shipped with CUDA 11 by default:

nvidia-cusolver-cu11==11.4.1.48
└── nvidia-cublas-cu11 [required: Any, installed: 11.11.3.6]

PyTorch 2.2.0, which ships with CUDA 12 by default:

nvidia-cusolver-cu12==11.5.4.101
├── nvidia-cublas-cu12 [required: Any, installed: 12.3.4.1]
├── nvidia-cusparse-cu12 [required: Any, installed: 12.2.0.103]
│   └── nvidia-nvjitlink-cu12 [required: Any, installed: 12.3.101]
└── nvidia-nvjitlink-cu12 [required: Any, installed: 12.3.101]
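
For anyone who wants to double-check those edges locally, a small stdlib-only check (a sketch; it assumes the nvidia-*-cu12 wheels are installed alongside torch):

from importlib.metadata import requires

# Print the declared dependencies of the two wheels involved in the missing
# libnvJitLink chain; requires() returns the raw requirement strings.
for dist in ("nvidia-cusolver-cu12", "nvidia-cusparse-cu12"):
    print(dist, "->", requires(dist))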

emrebayramc commented Mar 30, 2024

This is the workaround I applied in GitHub Actions for our project:

          LIB_DIR=""

          NVJITLINK_LIB_DIR=$(find -L $(pwd) -type d -path '*/nvjitlink/lib' -print -quit)
          if [ -n "$NVJITLINK_LIB_DIR" ]; then
            echo "Adding $NVJITLINK_LIB_DIR to LIB_DIR"
            LIB_DIR="$LIB_DIR:$NVJITLINK_LIB_DIR"
          fi

          CUSPARSE_LIB_DIR=$(find -L $(pwd) -type d -path '*/cusparse/lib' -print -quit)
          if [ -n "$CUSPARSE_LIB_DIR" ]; then
            echo "Adding $CUSPARSE_LIB_DIR to LIB_DIR"
            LIB_DIR="$LIB_DIR:$CUSPARSE_LIB_DIR"
          fi

          if [ -n "$LIB_DIR" ]; then
            LIB_DIR=${LIB_DIR#:}
            echo "Updating LD_LIBRARY_PATH with $LIB_DIR"
            echo "LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:$LIB_DIR" >> $GITHUB_ENV
          fi

mark64 commented Apr 3, 2024

I solved this in our bzlmod-based repo using this:

MODULE.bazel:

# Install pip packages.
pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip")
pip.parse(
    hub_name = "pypi",
    python_version = PYTHON_VERSION,
    requirements_lock = "//:.requirements_lock.txt",
)
pip.override(
    file = "torch-2.2.1-cp39-cp39-manylinux1_x86_64.whl",
    patch_strip = 1,
    patches = [
        # We have to patch pytorch to fix its dynamic library search code to work
        # with the bazel rules_python directory layout.
        "@//third_party/pytorch:pytorch.patch",
        "@//third_party/pytorch:pytorch_record.patch",
    ],
)
use_repo(pip, "pypi")

(Be careful about copy-pasting this; there's sensitive whitespace that won't copy correctly.)
pytorch_record.patch
pytorch.patch

cverrier commented Jun 3, 2024

@mark64 This seems to be a nice workaround, but as far as I understand this only works for Python 3.9 and PyTorch 2.2.1. Could you elaborate on how to proceed for other versions?

mark64 commented Jun 10, 2024

> @mark64 This seems to be a nice workaround, but as far as I understand this only works for Python 3.9 and PyTorch 2.2.1. Could you elaborate on how to proceed for other versions?

When you change versions, hopefully it's as easy as modifying the file = "torch-2.2.1-cp39-cp39-manylinux1_x86_64.whl" line to specify the right versions.

However, if the patches no longer apply properly, you'll need to re-create them from the new source files. I made pytorch.patch by going into the directory with the package extracted (check your bazel cache folder), copying the source files, modifying the copy, then running diff -u <old file name> <new file name>.

The record patch file was tricky though, because it should have been just the one change to the __init__.py line, but rules_python or pip or something didn't like it (I don't remember which). To get the rest of the changes, I had to:

  1. Create a patch with just the first hunk that modifies __init__.py's hash in the wheel's RECORD file (a sketch for recomputing that hash follows this list)
  2. Run a build/test that uses pytorch, pulling in the repo and patching files in the process
  3. Look at the error message provided by rules_python. It'll give you a diff with a + and - on random lines. Not sure why exactly.
  4. Make a patch hunk that just does the deletions it asks for
  5. Repeat steps 2 and 3, receiving a patch hunk in the error logs with just the + lines. Add that to your patch file.
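
For step 1, a RECORD entry has the form path,sha256=<urlsafe base64 digest without padding>,<size>, so the new hash line for the patched file can be recomputed with a few lines of Python (a sketch; the path is illustrative):

import base64
import hashlib
from pathlib import Path

# Recompute the RECORD line for the patched torch/__init__.py.
patched = Path("torch/__init__.py")
data = patched.read_bytes()
digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
print(f"{patched},sha256={digest},{len(data)}")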

parth-emancro commented Jun 13, 2024

Here's a little sed-based bash script I wrote to remove the dependency on specific pytorch.patch files for torch/__init__.py.
It programmatically finds def _load_global_deps() and reorders the elements in cuda_libs: Dict[str, str].

This works for torch-2.2.2 + python3.10 + bazel 6.4.0, but is agnostic to any of these versions IIUC.
Thanks for the discussion @georgevreilly!

https://gist.github.com/parth-emancro/ac075492a4c55b7aea149ba2aa3d2841

This looked like the easiest way to run torch in our bazel+python monorepo, given that loading patched .whl files is a huge pain with pipenv / Pipfile and the respective python interpreter paths (our setup cobbles together tooling from https://github.com/jacksmith15/bazel-python-demo and https://github.com/mvukov/rules_ros2).

@mark64 I would've used your solution if not for pipenv, since this script needs to be run manually and there's no good way to give a genrule() access to ~/.cache/bazel files 🥲

WSUFan commented Sep 26, 2024

> (quoting @mark64's MODULE.bazel workaround from above)

Hey, there's another approach. Bazel offers the --run_under option, which allows you to specify a prefix command to run before executing tests. You can use this to locate all libraries and preload them. @mark64
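
For example, a hypothetical --run_under wrapper could look something like this (a sketch under the assumption that the nvidia wheels end up under nvidia/<pkg>/lib directories reachable from the test's working directory; nothing here is an existing target in this repo):

#!/usr/bin/env python3
# Collect every nvidia/<pkg>/lib directory visible from the test's working
# directory onto LD_LIBRARY_PATH, then exec the real test command that Bazel
# appends after the --run_under prefix.
import glob
import os
import sys

lib_dirs = sorted(set(glob.glob(
    os.path.join(os.getcwd(), "**", "nvidia", "*", "lib"), recursive=True)))
parts = [os.environ.get("LD_LIBRARY_PATH", "")] + lib_dirs
os.environ["LD_LIBRARY_PATH"] = os.pathsep.join(p for p in parts if p)
os.execvp(sys.argv[1], sys.argv[1:])

Invoked as, e.g., bazel test //... --run_under=//tools:cuda_env_wrapper (the target name is made up).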

keith commented Sep 27, 2024

I debugged this a bit more in our case with 2.4.0. While there is already some logic for preloading the deps, the problem I noticed was that the preload logic only applies if the globals library fails to load with a CUDA-related library load failure. In our case the CUDA library itself wasn't the one that was failing, but the later _C library. Because of this, I changed the logic in a new patch to always preload the libs:

diff --git a/torch/__init__.py b/torch/__init__.py
index 1c4d5e4..37d8df6 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -219,15 +219,11 @@ def _load_global_deps() -> None:
 
     split_build_lib_name = LIBTORCH_PKG_NAME
     library_path = find_package_path(split_build_lib_name)
-
-    if library_path:
-        global_deps_lib_path = os.path.join(library_path, 'lib', lib_name)
-    try:
-        ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL)
-    except OSError as err:
-        # Can only happen for wheel with cuda libs as PYPI deps
-        # As PyTorch is not purelib, but nvidia-*-cu12 is
-        cuda_libs: Dict[str, str] = {
+    # Can only happen for wheel with cuda libs as PYPI deps
+    # As PyTorch is not purelib, but nvidia-*-cu12 is
+    cuda_libs: Dict[str, str] = {
+            'nvjitlink': 'libnvJitLink.so.*[0-9]',
+            'cusparse': 'libcusparse.so.*[0-9]',
             'cublas': 'libcublas.so.*[0-9]',
             'cudnn': 'libcudnn.so.*[0-9]',
             'cuda_nvrtc': 'libnvrtc.so.*[0-9]',
@@ -240,13 +236,13 @@ def _load_global_deps() -> None:
             'nccl': 'libnccl.so.*[0-9]',
             'nvtx': 'libnvToolsExt.so.*[0-9]',
         }
-        is_cuda_lib_err = [lib for lib in cuda_libs.values() if lib.split('.')[0] in err.args[0]]
-        if not is_cuda_lib_err:
-            raise err
-        for lib_folder, lib_name in cuda_libs.items():
-            _preload_cuda_deps(lib_folder, lib_name)
-        ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL)
 
+    for lib_folder, lib_name in cuda_libs.items():
+        _preload_cuda_deps(lib_folder, lib_name)
+
+    if library_path:
+        global_deps_lib_path = os.path.join(library_path, 'lib', lib_name)
+    ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL)
     if library_path:
         # loading libtorch_global_deps first due its special logic
         load_shared_libraries(library_path)

I don't think there is any realistic downside to this, since the libraries would have been loaded moments later anyway. And in the Bazel case, missing the vendored libraries and loading the system libraries instead seems a bit risky, since they could be different versions; in practice I don't know how much that matters, though.

keith added a commit to keith/pytorch that referenced this issue Sep 30, 2024
Previously cuda deps were only loaded if loading the globals library
failed with a cuda shared lib related error. It's possible for the globals
library to load successfully, but then for the torch native libraries to
fail with a cuda error. This now handles both of these cases.

Fixes pytorch#117350
keith linked a pull request Sep 30, 2024 that will close this issue
keith commented Sep 30, 2024

#137059

keith added a commit to keith/pytorch that referenced this issue Nov 19, 2024
keith added a commit to keith/pytorch that referenced this issue Jan 15, 2025
@lidingsnyk

For whatever reason, this patch is what works for me with torch 2.6 + cuda 12.4:

change.patch

Also, since we use WORKSPACE instead of MODULE.bazel, I had to reinvent pip.override with a repository_rule.

keith added a commit to keith/pytorch that referenced this issue Mar 5, 2025
keith added a commit to keith/pytorch that referenced this issue Mar 24, 2025
keith added a commit to keith/pytorch that referenced this issue Mar 24, 2025
@FrankPortman

@lidingsnyk can you please share the commands you ran in WORKSPACE to apply this patch? The documentation for applying those patches seems to be very sparse.

@lidingsnyk
Copy link
lidingsnyk commented Mar 28, 2025

@FrankPortman Since pip.override doesn't exist in WORKSPACE, I had to reinvent it in a convoluted way. A new target {any-name-you-see-fit} is created. To make sure the patch takes effect, the new "@{any-name-you-see-fit}:pkg" must be added to the deps of every py target that also depends on torch. Some targets depend on torch through a transitive dependency, e.g., through the transformers package; in that case, just make sure the new target is in the dependency tree whenever torch is.

This probably works fine; a bunch of common bash commands are used in patch_pip.bzl. It should work for Linux, and probably works for Mac (or can be fixed easily).

In the WORKSPACE I simply have:

load("//some/path/to:patch_pip.bzl", "patch_pip") # Had to invent this myself

patch_pip(
    name = "{any-name-you-see-fit}",
    src = "@{your torch target}",
    build_file = "//some/path:torch_patch_build_file",  # I just fished out the bazel target responsible for all the torch source files and modified it. See the example below.
    patch = "//some/path:torch_init.patch",
)

The patch_pip.bzl that I wrote:
patch_pip.txt

My torch_patch_build_file looks like the following

py_library(
    name = "pkg",
    srcs = glob(
        ["site-packages/**/*.py"],
        allow_empty = True,
    ),
    deps = [
        "@{your torch target}//:pkg",
    ],
    visibility = ["//visibility:public"],
)

I looked at the BUILD file of torch in my bazel runfiles. It has a target named pkg; I basically just copied it and made it depend on the original pkg target.

@FrankPortman

@lidingsnyk TY for the answer. I actually finally ended up understanding how the whl_patches arg to install_deps should work. Posting here in case it simplifies anything for you.

load("@3rdparty_deps//:requirements.bzl", "install_deps")

install_deps(
    whl_patches = {
        "//3rdparty:torch_cuda.patch": '{"whls": ["torch-2.6.0-cp311-none-macosx_11_0_arm64.whl", "torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl"],"patch_strip": 1}',
    },
)

Note the JSONified string.

@lidingsnyk

@FrankPortman Thanks! Yes this is much better than my approach... Great comment. It's not easy to figure this out from the official Bazel documentation...

keith added a commit to keith/pytorch that referenced this issue Apr 24, 2025
keith commented Apr 28, 2025

I realized today that my patch #137059 doesn't solve all cases, depending on your rpaths.

We noticed in our build that it is possible for some libs to be discovered from their system installations:

    133040:	 search cache=/etc/ld.so.cache
    133040:	  trying file=/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12

therefore sidestepping the preload logic. But later, when one of the libs that isn't installed by default (like cudnn) is loaded, it cannot be found and fails. At this point, if the pytorch "preload" logic runs, it will load versions of the nvidia libraries that potentially mismatch the system ones that have already been loaded.

It's unclear to me how much of a real-world issue that would cause, but I assume the nvidia library versions that pytorch pins are important, so it seems to me like we should always call the preload logic before trying to load any native library, so that the correct paths are used. (It might still be possible to conflict with system libraries if another native extension was loaded before pytorch, but that doesn't seem solvable.)
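
As a quick sanity check for the mismatch scenario, you can look at which library files actually got mapped into the process (a Linux-only sketch; the library names are just examples from this thread):

import torch  # noqa: F401  -- importing torch triggers the library loads

# Print every mapped file that looks like a CUDA/NVIDIA library, so you can
# see whether the vendored pip copies or the system ones were picked up.
names = ("libcudart", "libcudnn", "libnvJitLink", "libcublas")
seen = set()
with open("/proc/self/maps") as maps:
    for line in maps:
        fields = line.split()
        path = fields[-1] if len(fields) > 5 else ""
        if any(name in path for name in names):
            seen.add(path)
for path in sorted(seen):
    print(path)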

keith added a commit to keith/pytorch that referenced this issue Apr 28, 2025
keith added a commit to keith/pytorch that referenced this issue Apr 28, 2025