getting different results when adding `torch.Tensor` or python number to a DTensor - Is that expected? · Issue #145218 · pytorch/pytorch


Open
thevasudevgupta opened this issue Jan 20, 2025 · 3 comments
Labels
`module: dtensor` (distributed tensor tag), `oncall: distributed` (add this issue/PR to distributed oncall triage queue)

Comments

thevasudevgupta commented Jan 20, 2025

🐛 Describe the bug

# torchrun --nproc-per-node 2 scripts/dtensor.py

import os
import torch
from torch.distributed.tensor import init_device_mesh, Shard, distribute_tensor

use_tensor = False

rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))

torch.manual_seed(0)
tensor1 = torch.rand(1000, 88)
mesh = init_device_mesh("cpu", (world_size,))

# Reference: norm of the full (non-distributed) tensor.
norm1 = torch.linalg.vector_norm(tensor1)
norm1 += torch.tensor(2) if use_tensor else 2
print(f"{norm1}\n")

# Same computation on a DTensor sharded along dim 0.
tensor2 = distribute_tensor(tensor1, mesh, [Shard(dim=0)])
norm2 = torch.linalg.vector_norm(tensor2)
norm2 += torch.tensor(2) if use_tensor else 2

print(f"{norm2.full_tensor()}\n")

Setting `use_tensor = False` gives different results between the plain-tensor and DTensor paths - is that expected?

With `use_tensor = True`, both paths give the same result.

Versions

Collecting environment information...
PyTorch version: 2.5.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.2 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.6)
CMake version: version 3.30.2
Libc version: N/A

Python version: 3.10.8 (main, Nov 24 2022, 08:08:27) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-15.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M2

Versions of relevant libraries:
[pip3] flake8==7.1.1
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.25.2
[pip3] pytorch-lightning==2.0.1.post0
[pip3] torch==2.5.1
[pip3] torchaudio==2.0.0.dev20230302
[pip3] torchdata==0.6.1
[pip3] torchmetrics==0.11.4
[pip3] torchtext==0.15.2
[pip3] torchvision==0.19.0
[conda] numpy                     1.25.2                   pypi_0    pypi
[conda] pytorch-lightning         2.0.1.post0              pypi_0    pypi
[conda] torch                     2.5.1                    pypi_0    pypi
[conda] torchaudio                2.0.0.dev20230302          pypi_0    pypi
[conda] torchdata                 0.6.1                    pypi_0    pypi
[conda] torchmetrics              0.11.4                   pypi_0    pypi
[conda] torchtext                 0.15.2                   pypi_0    pypi
[conda] torchvision               0.19.0                   pypi_0    pypi

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @tianyu-l @XilunWu

@mikaylagawarecki added the `module: dtensor` and `oncall: distributed` labels on Jan 21, 2025
weifengpy commented Jan 24, 2025

I can confirm. The semantics of `Partial(sum)` differ depending on the right-hand operand:

  • DTensor (`Partial(sum)`) + `torch.tensor(1, device="cuda")`: reduce the DTensor first, then add 1
  • DTensor (`Partial(sum)`) + `1`: add 1 to each local tensor, then reduce
rank=0 scala_result.full_tensor()=tensor(12., device='cuda:0') tensor_result.full_tensor()=tensor(11., device='cuda:0')

Repro:

# torchrun --nproc-per-node 2 test_dtensor.py
import os
import torch
from torch.distributed.tensor import init_device_mesh, Shard, distribute_tensor

rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
torch.cuda.set_device(f"cuda:{rank}")
torch.manual_seed(0)

mesh = init_device_mesh("cuda", (world_size,))
unsharded_tensor = torch.ones(10, 1, device="cuda")
sharded_tensor = distribute_tensor(unsharded_tensor, mesh, [Shard(dim=0)])

norm2 = sharded_tensor.sum()  # DTensor with a Partial(sum) placement
scala_result = norm2 + 1                                # python scalar add
tensor_result = norm2 + torch.tensor(1, device="cuda")  # tensor add
print(f"{rank=} {scala_result.full_tensor()=} {tensor_result.full_tensor()=}\n")
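For intuition, here is a plain (non-distributed) sketch of the two orders of operations, assuming 2 ranks where each rank's local shard sums to 5.0 (so the true global sum is 10.0). The rank count and local values are illustrative, not taken from any DTensor internals:

```python
import torch

# Simulated Partial(sum) state: one local partial sum per rank.
world_size = 2
local_sums = [torch.tensor(5.0) for _ in range(world_size)]

# Python scalar path: the scalar is added to each rank's local tensor
# *before* the reduction, so it ends up added world_size times.
scalar_path = sum(t + 1 for t in local_sums)       # (5+1) + (5+1) = 12

# torch.Tensor path: the partial is reduced first, then 1 is added once.
tensor_path = sum(local_sums) + torch.tensor(1.0)  # 10 + 1 = 11

print(scalar_path, tensor_path)  # tensor(12.) tensor(11.)
```

This reproduces the `12.` vs `11.` gap in the output above: the scalar path effectively adds `world_size * 1` instead of `1`.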

cc @tianyu-l regarding DTensor semantics

@tianyu-l
Contributor

cc: @XilunWu @wz337

@main-horse
Contributor

Possibly related issue: #147180?
