From 43bc90d05ff8c6e42063830df6aa7b0ee4102779 Mon Sep 17 00:00:00 2001 From: eellison Date: Fri, 28 Feb 2025 17:37:49 -0800 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- torch/_inductor/ir.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 03205c43295bff..a0014f03d0dc7e 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -2798,9 +2798,7 @@ def fake_reindex(index): # type: ignore[no-untyped-def] # TODO: unbacked should not diverge from backed in determining striding # Need to require contiguous here instead of realize, see: # https://github.com/pytorch/pytorch/issues/145561 - x = ExternKernel.require_exact_strides( - x, FlexibleLayout.contiguous_strides(x.get_size()) - ) + x = ExternKernel.require_contiguous(x) storage, old_layout = as_storage_and_layout(x, want_contiguous=True) new_layout = FixedLayout( @@ -5356,7 +5354,7 @@ def require_channels_last_3d(cls, x): # type: ignore[no-untyped-def] @classmethod def require_contiguous(cls, x): # type: ignore[no-untyped-def] - return cls.require_stride_order(x, list(reversed(range(len(x.get_size()))))) + return cls.require_exact_strides(x, ir.FlexibleLayout.contiguous_strides(x.get_size())) def apply_constraint(self) -> None: pass From 3d8508cfdac28847252e266d438bed6b1a087b16 Mon Sep 17 00:00:00 2001 From: eellison Date: Fri, 28 Feb 2025 18:24:14 -0800 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- torch/_inductor/ir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index a0014f03d0dc7e..607dc473c8bd57 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -5354,7 +5354,7 @@ def require_channels_last_3d(cls, x): # type: ignore[no-untyped-def] @classmethod def require_contiguous(cls, x): # type: ignore[no-untyped-def] - return cls.require_exact_strides(x, ir.FlexibleLayout.contiguous_strides(x.get_size())) + return cls.require_exact_strides(x, FlexibleLayout.contiguous_strides(x.get_size())) def apply_constraint(self) -> None: pass