[DTensor] `Partial(sum)` reductions are wrongly cached (?) · Issue #147180 · pytorch/pytorch

Labels: module: dtensor, oncall: distributed, triaged


main-horse (Contributor) commented Feb 14, 2025

🐛 Describe the bug

First of all, a very simple motivating example:

# OMP_NUM_THREADS=1 torchrun --nproc-per-node 2 what.py
import os
import torch
from torch.distributed.tensor import DTensor, Partial, init_device_mesh

# Create mesh
mesh = init_device_mesh('cuda', (int(os.environ.get("WORLD_SIZE", "1")),))

# Create random local tensor (different seed on each rank)
randn_local_tensor = torch.randn(4096, 4096, device='cuda')/64

# Create Partial(sum) DTensor from local tensors
dt = DTensor.from_local(randn_local_tensor, mesh, placements=[Partial()])

# Expected: -5*dt != 2*dt (because dt is just random)
assert not (-5*dt.full_tensor() == 2*dt.full_tensor()).all()

# Not expected: when dt is Partial, -5*dt == 2*dt ???
assert (-5*dt == 2*dt).all() # <-- What?

# Exit
torch.distributed.destroy_process_group()

In the above code, we

  1. create a Partial() DTensor from a different local randn tensor on each rank.
  2. check that -5*dt and 2*dt differ when computed on the replicated full tensor, dt.full_tensor().
  3. find that -5*dt and 2*dt return the same result (???) when the Partial() dt is used directly.
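For reference, Partial(sum) means the logical (global) tensor is the elementwise sum of the per-rank local tensors, so multiplying by a scalar commutes with the reduction: c * sum(x_r) == sum(c * x_r). The single-process sketch below uses plain Python lists as stand-ins for the two ranks' local tensors to show what -5*dt and 2*dt should reduce to. The helpers `full_tensor` and `scale_partial` here are hypothetical names for this illustration only, not DTensor APIs.

```python
# Hypothetical single-process model of Partial(sum) semantics; not DTensor code.
from typing import List

Vector = List[float]


def full_tensor(local_shards: List[Vector]) -> Vector:
    """Reduce Partial(sum) shards: the global value is the elementwise sum."""
    return [sum(vals) for vals in zip(*local_shards)]


def scale_partial(local_shards: List[Vector], c: float) -> List[Vector]:
    """Scaling is linear, so each rank may scale its own shard before the sum."""
    return [[c * v for v in shard] for shard in local_shards]


# Two "ranks" holding different local values (stand-ins for the randn tensors).
shards = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.5]]

full = full_tensor(shards)                          # [2.0, -0.75, 1.5]
minus5 = full_tensor(scale_partial(shards, -5.0))   # should equal -5 * full
plus2 = full_tensor(scale_partial(shards, 2.0))     # should equal 2 * full

assert minus5 == [-5.0 * v for v in full]
assert plus2 == [2.0 * v for v in full]
# For any nonzero input, -5*x and 2*x must differ; the repro above violates this.
assert minus5 != plus2
```

The values are chosen to be exactly representable in binary floating point, so the equality checks are exact rather than approximate.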

If we print out the values involved, the issue becomes clearer:

dt=DTensor(local_tensor=tensor[4096, 4096] n=16777216 (64Mb) x∈[-0.083, 0.081] μ=-5.313e-07 σ=0.016 cuda:0, device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Partial(sum),))
-5 * dt.full_tensor()=tensor[4096, 4096] n=16777216 (64Mb) x∈[-0.568, 0.607] μ=3.048e-05 σ=0.111 cuda:0
 2 * dt.full_tensor()=tensor[4096, 4096] n=16777216 (64Mb) x∈[-0.243, 0.227] μ=-1.219e-05 σ=0.044 cuda:0
-5 * dt=DTensor(local_tensor=tensor[4096, 4096] n=16777216 (64Mb) x∈[-0.568, 0.607] μ=3.048e-05 σ=0.111 cuda:0, device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Replicate(),))
 2 * dt=DTensor(local_tensor=tensor[4096, 4096] n=16777216 (64Mb) x∈[-0.568, 0.607] μ=3.048e-05 σ=0.111 cuda:0, device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Replicate(),))

Somehow, the result of -5*dt is cached and reused as the return value for 2*dt...

The following assertion also passes, i.e. both results share the same storage:

assert (-5*dt).to_local().data_ptr() == (2*dt).to_local().data_ptr()

I do not know how to debug what is happening further.
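A shared data_ptr() across the two results is the kind of symptom one would expect if some memoization layer keyed a computed result on the input (and its placements) but not on the scalar operand. The sketch below is a purely hypothetical, self-contained illustration of that failure mode: `buggy_scale`, `fixed_scale`, and the cache dictionaries are invented for this example, are not PyTorch internals, and nothing in this issue establishes that this is the actual cause. Here `a is b` plays the role of the matching data_ptr() check.

```python
# Hypothetical illustration of a cache whose key omits an operand.
# Nothing here is PyTorch code; it only reproduces the *shape* of the symptom.
from typing import Dict, List, Tuple

_buggy_cache: Dict[Tuple[int, str], List[float]] = {}
_fixed_cache: Dict[Tuple[int, str, float], List[float]] = {}


def buggy_scale(data: List[float], placement: str, c: float) -> List[float]:
    key = (id(data), placement)          # BUG: the scalar c is not part of the key
    if key not in _buggy_cache:
        _buggy_cache[key] = [c * v for v in data]
    return _buggy_cache[key]             # same object returned for any c


def fixed_scale(data: List[float], placement: str, c: float) -> List[float]:
    key = (id(data), placement, c)       # every input that affects the result is keyed
    if key not in _fixed_cache:
        _fixed_cache[key] = [c * v for v in data]
    return _fixed_cache[key]


x = [1.0, -2.0]
a = buggy_scale(x, "Partial(sum)", -5.0)
b = buggy_scale(x, "Partial(sum)", 2.0)   # silently returns the cached -5*x
c = fixed_scale(x, "Partial(sum)", -5.0)
d = fixed_scale(x, "Partial(sum)", 2.0)
```

The general rule the sketch encodes is that a result cache is only safe when its key covers every argument that can change the output, including non-tensor scalars.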

Versions

Collecting environment information...
PyTorch version: 2.6.0+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov  6 2024, 20:22:13) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.8.0-49-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3

Nvidia driver version: 550.127.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        52 bits physical, 57 bits virtual
Byte Order:                           Little Endian
CPU(s):                               208
On-line CPU(s) list:                  0-207
Vendor ID:                            GenuineIntel
Model name:                           Intel(R) Xeon(R) Platinum 8480+
CPU family:                           6
Model:                                143
Thread(s) per core:                   2
Core(s) per socket:                   52
Socket(s):                            2
Stepping:                             8
BogoMIPS:                             4000.00
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd arat vnmi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk avx512_fp16 arch_capabilities
Virtualization:                       VT-x
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            6.5 MiB (208 instances)
L1i cache:                            6.5 MiB (208 instances)
L2 cache:                             416 MiB (104 instances)
L3 cache:                             32 MiB (2 instances)
NUMA node(s):                         2
NUMA node0 CPU(s):                    0-103
NUMA node1 CPU(s):                    104-207
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Unknown: No mitigations
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] lovely-numpy==0.2.13
[pip3] numpy==2.2.2
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.23.4
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] torch==2.6.0
[pip3] triton==3.2.0
[conda] Could not collect

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @tianyu-l @XilunWu

jbschlosser added the oncall: distributed and module: dtensor labels Feb 14, 2025
fduwjj added the triaged label Apr 23, 2025
XilunWu removed their assignment May 12, 2025