DTensor does not support nn.init.eye_ #136946


Open
ringohoffman opened this issue Sep 29, 2024 · 3 comments
Labels
module: dtensor (distributed tensor tag)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@ringohoffman
Contributor
ringohoffman commented Sep 29, 2024

🐛 Describe the bug

Related:

nn.init.eye_ is not supported on DTensors. I wonder what other in-place nn.init functions are not supported?

# Modified from https://github.com/pytorch/pytorch/tree/main/torch/distributed/_tensor
# to run this file (i.e. dtensor_example.py):
# torchrun --standalone --nnodes=1 --nproc-per-node=1 dtensor_example.py
import os
import torch
from torch import nn
from torch.distributed._tensor import init_device_mesh, Shard, distribute_tensor

mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

tensor = torch.rand((3, 3))
my_dtensor = distribute_tensor(tensor, mesh, [Shard(dim=0)])
nn.init.eye_(my_dtensor)
[rank0]: NotImplementedError: Operator aten.eye.m_out does not have a sharding strategy registered.

Versions

torch==2.4.0

cc @wanchaol @tianyu-l @wz337 @XilunWu @d4l3k

@ringohoffman
Contributor Author

cc: @wanchaol

@awgu added the module: dtensor label Sep 29, 2024
@ringohoffman
Contributor Author
# Modified from https://github.com/pytorch/pytorch/tree/main/torch/distributed/_tensor
# to run this file (i.e. dtensor_example.py):
# torchrun --standalone --nnodes=1 --nproc-per-node=1 dtensor_example.py
import os
import torch
from torch.distributed._tensor import DTensor, init_device_mesh, Shard, distribute_tensor


def eye_(t: DTensor) -> None:
    diag = torch.arange(0, min(t.shape), device=t.device)
    diag = distribute_tensor(diag, t.device_mesh, t.placements)

    fill_value = torch.tensor([1.0], dtype=t.dtype, device=t.device)
    fill_value = distribute_tensor(fill_value, t.device_mesh, t.placements)

    t.fill_(0.0)
    t[diag, diag] = fill_value


tensor = torch.rand((3, 3))
mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
my_dtensor = distribute_tensor(tensor, mesh, [Shard(dim=0)])
eye_(my_dtensor)

I was going to try something like this as a workaround, but I'm getting an unexpected error about torch.Tensor even though I converted all components of the operation to DTensor:

[rank0]:   File "/home/matthew/llama3.1/dtensor_example.py", line 24, in <module>
[rank0]:     eye_(my_dtensor)
[rank0]:   File "/home/matthew/llama3.1/dtensor_example.py", line 18, in eye_
[rank0]:     t[diag, diag] = fill_value
[rank0]:   File "/home/matthew/.conda/envs/llama3/lib/python3.10/site-packages/torch/_compile.py", line 31, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:   File "/home/matthew/.conda/envs/llama3/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home/matthew/.conda/envs/llama3/lib/python3.10/site-packages/torch/distributed/_tensor/api.py", line 309, in __torch_dispatch__
[rank0]:     return DTensor._op_dispatcher.dispatch(
[rank0]:   File "/home/matthew/.conda/envs/llama3/lib/python3.10/site-packages/torch/distributed/_tensor/_dispatch.py", line 117, in dispatch
[rank0]:     self.sharding_propagator.propagate(op_info)
[rank0]:   File "/home/matthew/.conda/envs/llama3/lib/python3.10/site-packages/torch/distributed/_tensor/_sharding_prop.py", line 185, in propagate
[rank0]:     output_sharding = self.propagate_op_sharding(op_info.schema)
[rank0]:   File "/home/matthew/.conda/envs/llama3/lib/python3.10/site-packages/torch/distributed/_tensor/_sharding_prop.py", line 197, in propagate_op_sharding_non_cached
[rank0]:     out_tensor_meta = self._propagate_tensor_meta(op_schema)
[rank0]:   File "/home/matthew/.conda/envs/llama3/lib/python3.10/site-packages/torch/distributed/_tensor/_sharding_prop.py", line 109, in _propagate_tensor_meta
[rank0]:     fake_out = op_schema.op(*fake_args, **fake_kwargs)
[rank0]:   File "/home/matthew/.conda/envs/llama3/lib/python3.10/site-packages/torch/_ops.py", line 667, in __call__
[rank0]:     return self_._op(*args, **kwargs)
[rank0]:   File "/home/matthew/.conda/envs/llama3/lib/python3.10/site-packages/torch/_compile.py", line 31, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:   File "/home/matthew/.conda/envs/llama3/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home/matthew/.conda/envs/llama3/lib/python3.10/site-packages/torch/distributed/_tensor/api.py", line 309, in __torch_dispatch__
[rank0]:     return DTensor._op_dispatcher.dispatch(
[rank0]:   File "/home/matthew/.conda/envs/llama3/lib/python3.10/site-packages/torch/distributed/_tensor/_dispatch.py", line 115, in dispatch
[rank0]:     op_info = self.unwrap_to_op_info(op_call, args, kwargs)
[rank0]:   File "/home/matthew/.conda/envs/llama3/lib/python3.10/site-packages/torch/distributed/_tensor/_dispatch.py", line 348, in unwrap_to_op_info
[rank0]:     args_schema.append(try_get_replicate_spec(arg, mesh))
[rank0]:   File "/home/matthew/.conda/envs/llama3/lib/python3.10/site-packages/torch/distributed/_tensor/_dispatch.py", line 329, in try_get_replicate_spec
[rank0]:     raise RuntimeError(
[rank0]: RuntimeError: aten.index_put_.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!
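A possible way around the aten.index_put_ failure might be to write the diagonal into each rank's local shard directly instead of indexing through the DTensor. A minimal, untested sketch, assuming a 1-D mesh with a single Shard(dim=0) placement and a first dimension evenly divisible by the world size (eye_local_ is a hypothetical helper, not an existing API):

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor


def eye_local_(t: DTensor) -> None:
    # to_local() returns this rank's shard of the rows; writing into it
    # updates the DTensor's storage in place.
    local = t.to_local()
    rows_per_rank = t.shape[0] // dist.get_world_size()
    row_offset = dist.get_rank() * rows_per_rank  # global index of local row 0

    local.zero_()
    for i in range(local.shape[0]):
        col = row_offset + i  # diagonal column for this global row
        if col < t.shape[1]:
            local[i, col] = 1.0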

@drisspg added the triaged label Oct 1, 2024
@XilunWu self-assigned this Dec 16, 2024
@ad8e
Contributor
ad8e commented Apr 3, 2025

Workaround (I have not tested this exact code because mine handles modules instead):

def identity_init(tensor):
    from torch.distributed.tensor import DTensor, distribute_tensor
    import torch.distributed as dist

    if isinstance(tensor, DTensor):
        # Materialize a full-size tensor, fill it with the identity on rank 0,
        # and broadcast so every rank holds the same values.
        full_tensor = torch.empty(tensor.shape, dtype=tensor.dtype, device=tensor.device)
        if dist.get_rank() == 0:
            torch.nn.init.eye_(full_tensor)
        dist.broadcast(
            full_tensor,
            0,
            group=dist.group.WORLD,
        )
        # Re-shard the replicated identity and copy it into the DTensor in place.
        tensor.copy_(distribute_tensor(full_tensor, device_mesh=tensor.device_mesh, placements=tensor.placements))
    else:
        torch.nn.init.eye_(tensor)
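
For reference, a hypothetical usage sketch, assuming the same torchrun launch and mesh setup as the repro above:

import os
import torch
from torch.distributed._tensor import init_device_mesh, Shard, distribute_tensor

mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
weight = distribute_tensor(torch.empty(4, 4), mesh, [Shard(dim=0)])
identity_init(weight)             # DTensor path: rank 0 builds the identity, then it is broadcast
identity_init(torch.empty(4, 4))  # plain torch.Tensor path falls back to nn.init.eye_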

@XilunWu removed their assignment May 12, 2025