8000 [Inductor] Fix the lowering of squeeze when input is not contiguous · pytorch/pytorch@32a83bb · GitHub
[go: up one dir, main page]

Skip to content

Commit 32a83bb

Browse files
[Inductor] Fix the lowering of squeeze when input is not contiguous
ghstack-source-id: 1680d45 Pull Request resolved: #146746
1 parent fa0592b commit 32a83bb

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

test/inductor/test_unbacked_symints.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -385,9 +385,6 @@ def fn(t, start, length):
385385
@skipGPUIf(not HAS_GPU, "requires gpu and triton")
386386
@dynamo_config.patch(capture_dynamic_output_shape_ops=True)
387387
def test_issue_143498(self, device):
388-
if device == "cpu":
389-
raise unittest.SkipTest("CPU Failure")
390-
391388
class Model(torch.nn.Module):
392389
def __init__(self) -> None:
393390
super().__init__()

torch/_inductor/lowering.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1006,15 +1006,30 @@ def squeeze(x, dim=None):
10061006
dims = OrderedSet((dim,) if not isinstance(dim, tuple) else dim)
10071007

10081008
new_shape = []
1009+
new_stride = []
1010+
is_storage_and_layout = ir.is_storage_and_layout(x)
1011+
original_stride = x.get_stride() if is_storage_and_layout else None
1012+
new_offset = x.get_layout().offset if is_storage_and_layout else None
10091013
for d, s in enumerate(x.get_size()):
10101014
if not (
10111015
d in dims
10121016
and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1), size_oblivious=True)
10131017
):
10141018
new_shape.append(s)
1019+
if is_storage_and_layout:
1020+
assert isinstance(original_stride, list)
1021+
new_stride.append(original_stride[d])
10151022

10161023
# squeeze does nothing if the size isn't 1
1017-
return view(x, new_shape) if new_shape != x.get_size() else x
1024+
return (
1025+
(
1026+
as_strided(x, new_shape, new_stride, new_offset)
1027+
if is_storage_and_layout
1028+
else view(x, new_shape)
1029+
)
1030+
if new_shape != x.get_size()
1031+
else x
1032+
)
10181033

10191034

10201035
@register_lowering(aten.squeeze_copy, type_promotion_kind=None)

0 commit comments

Comments
 (0)
0