8000 Update · pytorch/pytorch@0e955ea · GitHub
[go: up one dir, main page]

Skip to content

Commit 0e955ea

Browse files
Update
[ghstack-poisoned]
1 parent e5f48e5 commit 0e955ea

File tree

2 files changed

+4
-16
lines changed

2 files changed

+4
-16
lines changed

torch/_inductor/ir.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2769,7 +2769,9 @@ def fake_reindex(index): # type: ignore[no-untyped-def]
27692769
# TODO: unbacked should not diverge from backed in determining striding
27702770
# Need to require contiguous here instead of realize, see:
27712771
# https://github.com/pytorch/pytorch/issues/145561
2772-
x = ExternKernel.require_contiguous(x)
2772+
x = ExternKernel.require_exact_strides(
2773+
x, FlexibleLayout.contiguous_strides(x.get_size())
2774+
)
27732775

27742776
storage, old_layout = as_storage_and_layout(x, want_contiguous=True)
27752777
new_layout = FixedLayout(

torch/_inductor/lowering.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,29 +1006,15 @@ def squeeze(x, dim=None):
10061006
dims = OrderedSet((dim,) if not isinstance(dim, tuple) else dim)
10071007

10081008
new_shape = []
1009-
new_stride = []
1010-
is_storage_and_layout = ir.is_storage_and_layout(x)
1011-
original_stride = x.get_stride() if is_storage_and_layout else []
1012-
new_offset = x.get_layout().offset if is_storage_and_layout else None
10131009
for d, s in enumerate(x.get_size()):
10141010
if not (
10151011
d in dims
10161012
and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1), size_oblivious=True)
10171013
):
10181014
new_shape.append(s)
1019-
if is_storage_and_layout:
1020-
new_stride.append(original_stride[d])
10211015

10221016
# squeeze does nothing if the size isn't 1
1023-
return (
1024-
(
1025-
as_strided(x, new_shape, new_stride, new_offset)
1026-
if is_storage_and_layout
1027-
else view(x, new_shape)
1028-
)
1029-
if new_shape != x.get_size()
1030-
else x
1031-
)
1017+
return view(x, new_shape) if new_shape != x.get_size() else x
10321018

10331019

10341020
@register_lowering(aten.squeeze_copy, type_promotion_kind=None)

0 commit comments

Comments
 (0)
0