Commit 346e0f6
[Inductor] short-term fix for needs_fixed_stride_order silent incorrectness (#133452) (#133888)
This is a low-risk short-term fix for #128084, for the purposes of 2.4.1. The actual fix for that issue is riskier and we'll target 2.5.

needs_fixed_stride_order is silently incorrect with args that are mutable, because it creates clones of those args, writes into them, and doesn't update the original args.

This PR makes needs_fixed_stride_order not apply to inputs that are being mutated. It doesn't completely fix the problem, but it makes things less incorrect: most of the time the input already has the correct strides but Inductor fails to recognize it, and in those cases writing directly to the input is fine.

Test Plan:
- new test

Pull Request resolved: #133452
Approved by: https://github.com/eellison
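To make the failure mode concrete, here is a minimal standalone sketch (not part of this commit; constrained_call is a hypothetical helper) of the "clone, write into the clone, drop the result" pattern: the in-place write lands in the clone, and the caller's tensor is silently left stale.

import torch

# Hypothetical sketch of the bug pattern: a stride "constraint" that clones a
# mutated argument, so the in-place write goes into the clone, not the original.
def constrained_call(op, dst, src):
    dst_fixed = dst.contiguous()  # clones when dst is not already contiguous
    op(dst_fixed, src)            # the mutation happens on the clone...
    return dst                    # ...and the caller's dst never sees it

dst = torch.zeros(2, 3).t()       # non-contiguous view, so .contiguous() clones
src = torch.ones(3, 2)
constrained_call(torch.Tensor.copy_, dst, src)
print(dst)                        # still all zeros: the update was silently lost

The new test below exercises exactly this shape of program (a mutating custom op tagged needs_fixed_stride_order) through Inductor's fallback path.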
1 parent 362a6ca commit 346e0f6

File tree

3 files changed: +66 -4 lines

test/inductor/test_torchinductor.py
Lines changed: 29 additions & 0 deletions

@@ -9845,6 +9845,35 @@ def fn(x):
         # But because our custom op needs fixed layout, the assertions in the custom op will pass
         self.common(fn, (inp,), check_lowp=False)
 
+    @config.patch(implicit_fallbacks=True)
+    def test_mutable_custom_op_fixed_layout(self):
+        with torch.library._scoped_library("mylib", "DEF") as lib:
+            lib.define(
+                "copy_(Tensor(a!) dst, Tensor src) -> ()",
+                tags=torch.Tag.needs_fixed_stride_order,
+            )
+
+            @torch.library.impl(lib, "copy_", "Meta")
+            def _(dst, src):
+                return None
+
+            @torch.library.impl(lib, "copy_", "CompositeExplicitAutograd")
+            def _(dst, src):
+                dst.copy_(src)
+
+            def f(x):
+                full_default_3 = torch.full([3], 7.0, device="cpu")
+                chunk_cat_default_1 = torch.ops.mylib.copy_.default(full_default_3, x)
+                mul_out = torch.mul(full_default_3, full_default_3)
+                return mul_out
+
+            x = torch.arange(3, dtype=torch.float, device="cpu")
+            eager_out = f(x)
+
+            compiled_inductor_f = torch.compile(f, backend="inductor", fullgraph=True)
+            compiled_inductor_out = compiled_inductor_f(x)
+            self.assertEqual(compiled_inductor_out, eager_out)
+
     @requires_gpu()
     @config.patch(implicit_fallbacks=True)
     def test_custom_op_fixed_layout_channels_last(self):

torch/_inductor/graph.py
Lines changed: 5 additions & 3 deletions

@@ -1,4 +1,5 @@
 # mypy: allow-untyped-defs
+import functools
 import itertools
 import logging
 import operator
@@ -934,12 +935,13 @@ def get_custom_op_layout_constraints(target, args, kwargs):
                 # We have to set the current args because call_function will immediately
                 # evaluate this lowering after creating the fallback, without evaluating
                 # the layout constraint
-                args, kwargs = constrain_to_fx_strides(
-                    self.current_node, *args, **kwargs
+                constrain_fn = functools.partial(
+                    constrain_to_fx_strides, ignore_mutated_args_FIXME=True
                 )
+                args, kwargs = constrain_fn(self.current_node, *args, **kwargs)
                 # Also register the layout constraint so when the fallback
                 # is used again, we can constrain the args to the same layout
-                layout_constraint = constrain_to_fx_strides
+                layout_constraint = constrain_fn
             return layout_constraint, args, kwargs
 
         if target not in lowerings:
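A note on the graph.py hunk above: functools.partial bakes ignore_mutated_args_FIXME=True into the callable once, so the same configured function is both applied immediately and registered as the reusable layout_constraint. A tiny standalone sketch of that pattern (made-up names, not Inductor's API):

import functools

# Made-up stand-in for constrain_to_fx_strides: the keyword toggles behavior.
def constrain(node, *args, ignore_mutated_args_FIXME=False, **kwargs):
    mode = "skip-mutated-args" if ignore_mutated_args_FIXME else "constrain-all-args"
    return mode, node, args, kwargs

# Bind the flag once; every later call site gets the same behavior for free.
constrain_fn = functools.partial(constrain, ignore_mutated_args_FIXME=True)

print(constrain_fn("node_a", 1, 2))          # used immediately at lowering time
registered_layout_constraint = constrain_fn  # stored for later fallback reuse
print(registered_layout_constraint("node_b"))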

torch/_inductor/lowering.py
Lines changed: 32 additions & 1 deletion

@@ -2004,13 +2004,44 @@ def require_channels_last(_, *args, **kwargs):
     return args, kwargs
 
 
-def constrain_to_fx_strides(fx_node, *args, **kwargs):
+def constrain_to_fx_strides(fx_node, *args, ignore_mutated_args_FIXME=False, **kwargs):
     def apply_constraint(arg, fx_arg):
         if isinstance(arg, ir.IRNode):
             stride_order = ir.get_stride_order(fx_arg.meta["val"].stride())
             return ir.ExternKernel.require_stride_order(arg, stride_order)
         return arg
 
+    # There's a silent incorrectness bug where, if we constrain a mutated arg,
+    # we may end up cloning it, writing in-place to the clone, and then using
+    # the original value (instead of the cloned value). Our short-term fix for this
+    # is to never constrain mutated args; longer term we do want to fix this.
+    # https://github.com/pytorch/pytorch/issues/128084
+    if ignore_mutated_args_FIXME:
+        assert isinstance(fx_node.target, torch._ops.OpOverload)
+        schema = fx_node.target._schema
+
+        def maybe_apply_constraint(schema_arg, arg, fx_arg):
+            if schema_arg.alias_info is not None and schema_arg.alias_info.is_write:
+                return arg
+            return apply_constraint(arg, fx_arg)
+
+        new_args = []
+        new_kwargs = {}
+
+        for idx, (arg, fx_arg) in enumerate(zip(args, fx_node.args)):
+            schema_arg = schema.arguments[idx]
+            new_args.append(maybe_apply_constraint(schema_arg, arg, fx_arg))
+
+        schema_kwargs = {arg.name: arg for arg in schema.arguments}
+
+        for key in kwargs.keys():
+            arg = kwargs[key]
+            fx_arg = fx_node.kwargs[key]
+            schema_arg = schema_kwargs[key]
+            new_kwargs[key] = maybe_apply_constraint(schema_arg, arg, fx_arg)
+
+        return tuple(new_args), new_kwargs
+
     args = tuple(
         apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)
     )
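The mutated-arg check in maybe_apply_constraint comes straight from the op schema's alias information: an argument counts as mutated when its alias_info is present and flagged as a write. As a quick illustration (not part of the diff), the same check can be run by hand on any OpOverload; aten::copy_ is used below only as a familiar example of a mutating op.

import torch

# Sketch of the schema inspection maybe_apply_constraint relies on: an argument
# is treated as mutated when its alias_info exists and is flagged as a write.
schema = torch.ops.aten.copy_.default._schema
for arg in schema.arguments:
    is_mutated = arg.alias_info is not None and arg.alias_info.is_write
    print(f"{arg.name}: mutated={is_mutated}")
# Expected: "self" (declared Tensor(a!)) is mutated; "src" and "non_blocking"
# are not, so only "self" would be exempt from the stride-order constraint.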
