8000 remove guard_size_oblivious from unbind. by laithsakka · Pull Request #148815 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

remove guard_size_oblivious from unbind. #148815

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion test/test_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import torch.nn.functional as F
from torch import sym_int, SymBool, SymFloat, SymInt
from torch._C import _disabled_torch_function_impl
from torch._dynamo.testing import CompileCounterWithBackend
from torch._dynamo.testing import CompileCounter, CompileCounterWithBackend
from torch._inductor.utils import fresh_cache
from torch.fx.experimental import sym_node
from torch.fx.experimental.proxy_tensor import make_fx
Expand Down Expand Up @@ -3420,6 +3420,24 @@ def func(x, y):
# throws a data dependent error.
compiled_func(x, torch.tensor([5, 20]))

@skipIfTorchDynamo()
def test_unbind_not_dynamic(self):
cnt = CompileCounter()

@torch.compile(fullgraph=True, dynamic=True, backend=cnt)
def func(y):
return y.unbind(dim=2), y * 10

func(torch.ones(5, 6, 7, 8))
self.assertEqual(cnt.frame_count, 1)
# it can be dynamic in all dimentions except dim=2
func(torch.ones(4, 9, 7, 10))
self.assertEqual(cnt.frame_count, 1)

func(torch.ones(5, 6, 8, 8))
func(torch.ones(5, 6, 9, 8))
self.assertEqual(cnt.frame_count, 3)


instantiate_parametrized_tests(TestUnbacked)

Expand Down
7 changes: 4 additions & 3 deletions torch/_refs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4035,14 +4035,15 @@ def unflatten(a: TensorLikeType, dim: int, sizes: ShapeType) -> TensorLikeType:

@register_decomposition(aten.unbind)
def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType:
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

dim = utils.canonicalize_dim(t.ndim, dim)
torch._check_index(
len(t.shape) > 0,
lambda: "Dimension specified as 0 but tensor has no dimensions",
)
if guard_size_oblivious(t.shape[dim] == 0):

# Note: t.shape[dim] can't be dynamic or unbacked, even if we use guard_or_false here we will fail
# later in the split since t.shape[dim] control the number of output tensors.
if t.shape[dim] == 0:
return ()
else:
return tuple(
Expand Down
Loading
0