🐛 Describe the bug
import torch
import os
os.environ['NCCL_DEBUG'] = 'WARN'
from torch import nn
from torch import distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel
from torch.distributed._tensor import Replicate
device = 'cuda' if torch.cuda.is_available() else 'cpu'
backend = 'nccl' if device == 'cuda' else 'gloo'
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])
dist.init_process_group(backend=backend, world_size=world_size)
if device == 'cuda':
    torch.cuda.set_device(local_rank)
device_mesh = init_device_mesh(device, [2, 2], mesh_dim_names=['dp', 'tp'])
tp_mesh = device_mesh['tp']
def print_on_all_rank(string=None):
    rank = dist.get_rank()
    for i in range(world_size):
        if i == rank:
            print(f'Global rank: {rank}, tp_rank: {device_mesh["tp"].get_local_rank()}, dp_rank: {device_mesh["dp"].get_local_rank()}')
            if string is not None:
                print(string)
        dist.barrier()
with torch.device(device):
    model = nn.Embedding(num_embeddings=4, embedding_dim=1)
# hardcode the weights to visualize and reproduce
model.weight.data[0, 0] = 0.
model.weight.data[1, 0] = 1.
model.weight.data[2, 0] = 2.
model.weight.data[3, 0] = 3.
# apply tp
model = parallelize_module(
    model,
    tp_mesh,
    parallelize_plan=RowwiseParallel(
        input_layouts=Replicate(),
        output_layouts=Replicate(),
    ),
)
# apply fsdp2
from torch.distributed._composable.fsdp import fully_shard
model = fully_shard(model, mesh=device_mesh['dp'], reshard_after_forward=True)
# The following line doesn't work with 2.3.1. Use nightly instead
model.unshard()
# print embedding weights on each rank
print_on_all_rank(model.weight)
full_tensor = model.weight.full_tensor()
if rank == 0:
    print(full_tensor)
dist.barrier()
I am trying to use TP + FSDP2, but DTensor's `full_tensor()` appears to reassemble the weight in the wrong order. The example above contains only an embedding layer whose initial weights are [0, 1, 2, 3]. After applying TP and then FSDP sharding, we observe that:
tp_rank 0, dp_rank 0 has 0
tp_rank 1, dp_rank 0 has 2
tp_rank 0, dp_rank 1 has 1
tp_rank 1, dp_rank 1 has 3
This is correct given the sharding order (TP first, then FSDP).
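The per-rank values above follow from splitting dim 0 in two steps. Below is a minimal single-process sketch of that reasoning, assuming both RowwiseParallel and fully_shard split dim 0 into equal contiguous chunks; the helper `tp_then_fsdp_shards` is hypothetical and does not call any PyTorch API.

```python
def tp_then_fsdp_shards(rows, tp_size, dp_size):
    """Return {(dp_rank, tp_rank): local_rows} for TP sharding followed by FSDP sharding.

    Assumes len(rows) is divisible by tp_size * dp_size (even, contiguous chunks).
    """
    shards = {}
    tp_chunk = len(rows) // tp_size
    for tp_rank in range(tp_size):
        # Step 1 (TP): split dim 0 across the 'tp' mesh dimension.
        tp_local = rows[tp_rank * tp_chunk:(tp_rank + 1) * tp_chunk]
        dp_chunk = len(tp_local) // dp_size
        for dp_rank in range(dp_size):
            # Step 2 (FSDP): split the TP-local rows again across the 'dp' mesh dimension.
            shards[(dp_rank, tp_rank)] = tp_local[dp_rank * dp_chunk:(dp_rank + 1) * dp_chunk]
    return shards


shards = tp_then_fsdp_shards([0, 1, 2, 3], tp_size=2, dp_size=2)
for (dp_rank, tp_rank), local in sorted(shards.items()):
    print(f"tp_rank {tp_rank}, dp_rank {dp_rank} has {local}")
# tp_rank 0, dp_rank 0 has [0]
# tp_rank 1, dp_rank 0 has [2]
# tp_rank 0, dp_rank 1 has [1]
# tp_rank 1, dp_rank 1 has [3]
```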
However, calling `full_tensor()` returns [0, 2, 1, 3], which does not match the logical weight of the embedding layer ([0, 1, 2, 3]).
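The [0, 2, 1, 3] order is what results if the local shards are concatenated in mesh-dimension order, i.e. reading `placements=(Shard(dim=0), Shard(dim=0))` as "split over dp first, then over tp", even though the data was actually split over tp first. The single-process sketch below illustrates this; the interpretation is inferred from the output in this issue rather than taken from the DTensor source, and `reassemble` is a hypothetical helper.

```python
def reassemble(shards, dp_size, tp_size, tp_major):
    """Concatenate local shards along dim 0.

    tp_major=False: treat dim 0 as split over dp first, then tp
                    (the mesh-dimension order of the placements tuple).
    tp_major=True:  treat dim 0 as split over tp first, then dp
                    (the order in which TP and FSDP actually sharded it).
    """
    if tp_major:
        order = [(dp, tp) for tp in range(tp_size) for dp in range(dp_size)]
    else:
        order = [(dp, tp) for dp in range(dp_size) for tp in range(tp_size)]
    full = []
    for key in order:
        full.extend(shards[key])
    return full


# Local shards from the run in this issue, keyed by (dp_rank, tp_rank).
observed = {(0, 0): [0], (0, 1): [2], (1, 0): [1], (1, 1): [3]}
print(reassemble(observed, 2, 2, tp_major=False))  # [0, 2, 1, 3], same as full_tensor() above
print(reassemble(observed, 2, 2, tp_major=True))   # [0, 1, 2, 3], the original weights
```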
To run the above code:
torchrun --node_rank=0 --nproc_per_node=4 --nnodes=1 --standalone playground.py
This can be reproduced on either CPU or GPU. Thank you for your help!
Here is the output of my run.
Global rank: 0, tp_rank: 0, dp_rank: 0
DTensor(local_tensor=tensor([[0.]], device='cuda:0'), device_mesh=DeviceMesh([[0, 1], [2, 3]], mesh_dim_names=['dp', 'tp']), placements=(Shard(dim=0), Shard(dim=0)))
Global rank: 1, tp_rank: 1, dp_rank: 0
DTensor(local_tensor=tensor([[2.]], device='cuda:1'), device_mesh=DeviceMesh([[0, 1], [2, 3]], mesh_dim_names=['dp', 'tp']), placements=(Shard(dim=0), Shard(dim=0)))
Global rank: 2, tp_rank: 0, dp_rank: 1
DTensor(local_tensor=tensor([[1.]], device='cuda:2'), device_mesh=DeviceMesh([[0, 1], [2, 3]], mesh_dim_names=['dp', 'tp']), placements=(Shard(dim=0), Shard(dim=0)))
Global rank: 3, tp_rank: 1, dp_rank: 1
DTensor(local_tensor=tensor([[3.]], device='cuda:3'), device_mesh=DeviceMesh([[0, 1], [2, 3]], mesh_dim_names=['dp', 'tp']), placements=(Shard(dim=0), Shard(dim=0)))
NCCL version 2.20.5+cuda12.4
tensor([[0.],
        [2.],
        [1.],
        [3.]], device='cuda:0', grad_fn=<_ToTorchTensorBackward>)
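If the mismatch is purely a block permutation (a dp × tp grid of equal row blocks read in the transposed order), the logical order can be recovered from the returned tensor for any mesh shape. This is only a sketch under that assumption, inferred from the single 2 × 2 run above; `undo_block_transpose` is a hypothetical helper that operates on plain Python lists, not an official workaround.

```python
def undo_block_transpose(full, dp_size, tp_size):
    """Recover logical row order from a dim-0 sequence that was reassembled
    dp-major although it was sharded tp-major (TP first, then FSDP).

    Assumes len(full) is divisible by dp_size * tp_size (even sharding).
    """
    block = len(full) // (dp_size * tp_size)
    logical = [None] * len(full)
    for dp_rank in range(dp_size):
        for tp_rank in range(tp_size):
            src = (dp_rank * tp_size + tp_rank) * block  # where the dp-major reassembly put it
            dst = (tp_rank * dp_size + dp_rank) * block  # where it belongs logically
            logical[dst:dst + block] = full[src:src + block]
    return logical


print(undo_block_transpose([0, 2, 1, 3], dp_size=2, tp_size=2))  # [0, 1, 2, 3]
```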
Versions
Collecting environment information...
PyTorch version: 2.3.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: version 3.29.5
Libc version: glibc-2.31
Python version: 3.9.2 (default, Feb 28 2021, 17:03:44) [GCC 10.2.1 20210110] (64-bit runtime)
Python platform: Linux-5.10.135.bsk.6-amd64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H800
GPU 1: NVIDIA H800
GPU 2: NVIDIA H800
GPU 3: NVIDIA H800
Nvidia driver version: 535.129.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 52 bits physical, 57 bits virtual
CPU(s): 180
On-line CPU(s) list: 0-179
Thread(s) per core: 2
Core(s) per socket: 45
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 143
Model name: Intel(R) Xeon(R) Platinum 8457C
Stepping: 8
CPU MHz: 2599.851
BogoMIPS: 5199.70
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 4.2 MiB
L1i cache: 2.8 MiB
L2 cache: 180 MiB
L3 cache: 195 MiB
NUMA node0 CPU(s): 0-89
NUMA node1 CPU(s): 90-179
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd arat avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 arch_capabilities
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pytorch-triton==3.0.0+45fff310c8
[pip3] torch==2.3.1
[pip3] torchaudio==2.4.0.dev20240620+cu121
[pip3] torchvision==0.20.0.dev20240620+cu121
[pip3] triton==2.3.1
[conda] Could not collect
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @rohan-varma