Pre-dispatch export doesn't work with non-param/buffer tensor subclasses · Issue #153387 · pytorch/pytorch · GitHub
Open

angelayi opened this issue May 12, 2025 · 0 comments

Labels: export-triaged (issues reviewed by the PT2 Export team with a determined next step), oncall: export, oncall: pt2

angelayi (Contributor) commented May 12, 2025

🐛 Describe the bug

This piece of code doesn't work:

    def test_tensor_subclass(self):
        # pytree is needed for the tree_map_only/tree_flatten helpers used below
        import torch.utils._pytree as pytree
        from torch.utils._python_dispatch import return_and_correct_aliasing
        class MooTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, elem1):
                shape = elem1.shape
                kwargs = {}
                kwargs["strides"] = elem1.stride()
                kwargs["storage_offset"] = elem1.storage_offset()
                kwargs["device"] = elem1.device
                kwargs["layout"] = elem1.layout
                kwargs["requires_grad"] = elem1.requires_grad
                kwargs["dtype"] = elem1.dtype
                return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
            def __init__(self, elem1):
                self.elem1 = elem1
            def get_elem(self):
                return self.elem1
            def __repr__(self):
                inner_repr_1 = repr(self.elem1)
                return f"MooTensor({inner_repr_1})"
            def __tensor_flatten__(self):
                return ["elem1"], None
            @staticmethod
            def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
                elem1 = inner_tensors["elem1"]
                out = MooTensor(elem1)
                return out
            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs):
                # Don't use this tensor with view ops
                if kwargs is None:
                    kwargs = {}
                args_inner_1 = pytree.tree_map_only(
                    MooTensor, lambda x: x.elem1, args
                )
                kwargs_inner_1 = pytree.tree_map_only(
                    MooTensor, lambda x: x.elem1, kwargs
                )
                out_inner_1 = func(*args_inner_1, **kwargs_inner_1)
                out_inner_flat_1, spec = pytree.tree_flatten(out_inner_1)
        
                if func.is_view:
                    new_out = pytree.tree_unflatten(
                        (MooTensor(tensor1) for tensor1 in out_inner_flat_1),
                        spec,
                    )
                    return return_and_correct_aliasing(func, args, kwargs, new_out)
                return pytree.tree_unflatten(out_inner_flat_1, spec)
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.p1 = MooTensor(torch.ones(2, 2))
            def forward(self, x):
                return x + self.p1
            
        ep = torch.export._trace._export(M(), (torch.randn(2, 2),), strict=False, pre_dispatch=True)
        print(ep)
        print(ep.module()(torch.ones(2, 2) * 2))
        print(M()(torch.ones(2, 2) * 2))

I get the following error (full log in P1809463696):

  File "/data/users/angelayi/pytorch/torch/fx/experimental/proxy_tensor.py", line 1857, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "/data/users/angelayi/pytorch/torch/fx/_symbolic_trace.py", line 542, in call_module
    ret_val = forward(*args, **kwargs)
  File "/data/users/angelayi/pytorch/torch/fx/_symbolic_trace.py", line 818, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/data/users/angelayi/pytorch/torch/nn/modules/module.py", line 1755, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/users/angelayi/pytorch/torch/nn/modules/module.py", line 1766, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/angelayi/pytorch/moo.py", line 345, in forward
    res = x + self.p1
  File "/data/users/angelayi/pytorch/torch/fx/experimental/proxy_tensor.py", line 1326, in __torch_function__
    return func(*args, **kwargs)
  File "/data/users/angelayi/pytorch/torch/fx/experimental/proxy_tensor.py", line 1373, in __torch_function__
    return func(*args, **kwargs)
  File "/data/users/angelayi/pytorch/torch/_export/non_strict_utils.py", line 967, in __torch_function__
    return func(*args, **kwargs)
RuntimeError: Unable to cast NotImplemented to Tensor

The code works if I wrap self.p1 in torch.nn.Parameter or torch.nn.Buffer, and also if I change the _export call to pre_dispatch=False. I hit the same error if I construct the tensor subclass inside the forward function instead of in __init__.

Versions

main

cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @suo @ydwu4

avikchaudhuri added the export-triaged label on May 20, 2025