diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index 47af1d91793f8f..c6244bc81ed57b 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -923,10 +923,13 @@ void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) { sym_shape_meta.strides_.resize(dim_); if (dim_ > 0) { const auto last_idx = dim_ - 1; + auto accum = c10::SymInt(1); sym_shape_meta.strides_[last_idx] = c10::SymInt(1); for (auto i = last_idx - 1; i >= 0; --i) { - sym_shape_meta.strides_[i] = sym_shape_meta.strides_[i + 1] * - sym_shape_meta.sizes_[i + 1].max(1); + if (TORCH_GUARD_OR_TRUE(sym_shape_meta.sizes_[i + 1].sym_gt(1))) { + accum *= sym_shape_meta.sizes_[i + 1]; + } + sym_shape_meta.strides_[i] = accum; } } break; diff --git a/test/export/test_export.py b/test/export/test_export.py index 5e7d9a436e3d71..844be1e0ffccf9 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -4406,6 +4406,19 @@ def forward(self, xs, y): strict=strict, ) + def test_contiguous_unbacked_strides(self): + class Foo(torch.nn.Module): + def forward(self, xs): + u0, u1, u2 = xs.tolist() + return torch.empty(u0, u1, u2).contiguous() + + ep = export(Foo(), (torch.tensor([2, 3, 4]),)) + node = [node for node in ep.graph.nodes][-2] + val = node.meta["val"] + u0, u1, u2 = val.shape + self.assertEqual(val.stride(0), u1 * u2) + self.assertEqual(val.stride(1), u2) + def test_tolist(self): class M(torch.nn.Module): def forward(self, x): diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index a3458efbe65b2b..c023ca91ae74ae 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -885,7 +885,7 @@ def test_non_overlapping_and_dense_unbacked(self): cf( torch.empty_strided( (2, 3, 1, u0), - (3 * Max(1, u0), Max(1, u0), Max(1, u0), 1), + (3 * u0, u0, u0, 1), device="meta", ) ) diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index e8339b789f5442..cc4c94e7df447b 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -1638,15 +1638,14 @@ def make_contiguous_strides_for( if not shape: return () - from torch.fx.experimental.symbolic_shapes import is_nested_int + from torch.fx.experimental.symbolic_shapes import guard_or_true, is_nested_int multiplier: Union[_IntLikeT, int] = 1 strides = [] for l in reversed(shape): strides.append(multiplier) - multiplier *= ( - l if is_nested_int(l) else sym_max(l, 1) - ) # type:ignore[assignment] + if is_nested_int(l) or guard_or_true(l >= 1): + multiplier *= l # type:ignore[assignment] result = tuple(reversed(strides))