[dynamic shapes] stop writing Max(*, 1) for strides by pianpwk · Pull Request #150376 · pytorch/pytorch
Open · pianpwk wants to merge 4 commits into main
7 changes: 5 additions & 2 deletions c10/core/TensorImpl.cpp
@@ -923,10 +923,13 @@ void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) {
       sym_shape_meta.strides_.resize(dim_);
       if (dim_ > 0) {
         const auto last_idx = dim_ - 1;
+        auto accum = c10::SymInt(1);
         sym_shape_meta.strides_[last_idx] = c10::SymInt(1);
         for (auto i = last_idx - 1; i >= 0; --i) {
-          sym_shape_meta.strides_[i] = sym_shape_meta.strides_[i + 1] *
-              sym_shape_meta.sizes_[i + 1].max(1);
+          if (TORCH_GUARD_OR_TRUE(sym_shape_meta.sizes_[i + 1].sym_gt(1))) {
+            accum *= sym_shape_meta.sizes_[i + 1];
+          }
+          sym_shape_meta.strides_[i] = accum;
         }
       }
       break;

Review comment (Contributor), on the TORCH_GUARD_OR_TRUE line: We used to generate max(u0, 1), and now we generate just u0 unconditionally? Can you explain why this is safe?
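To make the reviewer's question concrete, here is a minimal sketch (plain sympy, not part of the PR; symbol names illustrative) of how the symbolic stride expressions for a contiguous tensor with unbacked sizes u0, u1, u2 change:

import sympy

# Free symbols standing in for PyTorch's unbacked SymInts (illustrative only).
u0, u1, u2 = sympy.symbols("u0 u1 u2", integer=True)

# Before this PR: each stride multiplier was clamped with Max(size, 1).
old_strides = (sympy.Max(u1, 1) * sympy.Max(u2, 1), sympy.Max(u2, 1), 1)

# After this PR: guard_or_true treats an undecidable size comparison as
# true, so unbacked sizes are multiplied in directly and the Max is gone.
new_strides = (u1 * u2, u2, 1)

print(old_strides)  # (Max(1, u1)*Max(1, u2), Max(1, u2), 1)
print(new_strides)  # (u1*u2, u2, 1)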
13 changes: 13 additions & 0 deletions test/export/test_export.py
@@ -4406,6 +4406,19 @@ def forward(self, xs, y):
                 strict=strict,
             )
 
+    def test_contiguous_unbacked_strides(self):
+        class Foo(torch.nn.Module):
+            def forward(self, xs):
+                u0, u1, u2 = xs.tolist()
+                return torch.empty(u0, u1, u2).contiguous()
+
+        ep = export(Foo(), (torch.tensor([2, 3, 4]),))
+        node = [node for node in ep.graph.nodes][-2]
+        val = node.meta["val"]
+        u0, u1, u2 = val.shape
+        self.assertEqual(val.stride(0), u1 * u2)
+        self.assertEqual(val.stride(1), u2)
+
     def test_tolist(self):
         class M(torch.nn.Module):
             def forward(self, x):
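For context on what the new test inspects: in an exported FX graph the last node is the output, so nodes[-2] should be the call producing the contiguous tensor, whose meta["val"] fake tensor carries the symbolic strides. A hedged standalone sketch (assuming a torch build that contains this change) of printing that metadata:

import torch
from torch.export import export

class Foo(torch.nn.Module):
    def forward(self, xs):
        u0, u1, u2 = xs.tolist()
        return torch.empty(u0, u1, u2).contiguous()

ep = export(Foo(), (torch.tensor([2, 3, 4]),))
for node in ep.graph.nodes:
    val = node.meta.get("val")
    if isinstance(val, torch.Tensor):
        # With this PR the strides print as (u1*u2, u2, 1) rather than
        # (Max(1, u1)*Max(1, u2), Max(1, u2), 1).
        print(node.name, tuple(val.shape), val.stride())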
2 changes: 1 addition & 1 deletion test/test_dynamic_shapes.py
@@ -885,7 +885,7 @@ def test_non_overlapping_and_dense_unbacked(self):
         cf(
             torch.empty_strided(
                 (2, 3, 1, u0),
-                (3 * Max(1, u0), Max(1, u0), Max(1, u0), 1),
+                (3 * u0, u0, u0, 1),
                 device="meta",
             )
         )

Review comment (pianpwk, Author), on the changed strides: This test case is no longer relevant.
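For intuition, a small sketch (concrete u0, not from the PR) showing that the updated strides are exactly the standard contiguous layout for this shape whenever u0 >= 1:

import torch

u0 = 4  # concrete stand-in for the unbacked symbol in the test
t = torch.empty_strided((2, 3, 1, u0), (3 * u0, u0, u0, 1), device="meta")
print(t.stride())         # (12, 4, 4, 1)
print(t.is_contiguous())  # True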
7 changes: 3 additions & 4 deletions torch/_prims_common/__init__.py
@@ -1638,15 +1638,14 @@ def make_contiguous_strides_for(
     if not shape:
         return ()
 
-    from torch.fx.experimental.symbolic_shapes import is_nested_int
+    from torch.fx.experimental.symbolic_shapes import guard_or_true, is_nested_int
 
     multiplier: Union[_IntLikeT, int] = 1
     strides = []
     for l in reversed(shape):
         strides.append(multiplier)
-        multiplier *= (
-            l if is_nested_int(l) else sym_max(l, 1)
-        )  # type:ignore[assignment]
+        if is_nested_int(l) or guard_or_true(l >= 1):
+            multiplier *= l  # type:ignore[assignment]
 
     result = tuple(reversed(strides))
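A quick usage check (hedged: make_contiguous_strides_for is a private helper, so this is for illustration only) showing that for concrete, backed sizes the result is unchanged, since size-0 and size-1 dims contribute a factor of 1 under both the old sym_max(l, 1) and the new guard_or_true(l >= 1) formulations:

from torch._prims_common import make_contiguous_strides_for

print(make_contiguous_strides_for((2, 3, 4)))  # (12, 4, 1)
print(make_contiguous_strides_for((2, 1, 4)))  # (4, 4, 1)
print(make_contiguous_strides_for((2, 0, 4)))  # (4, 4, 1) -- size-0 dim skipped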