Commit 1ba7adc
[Inductor] short-term fix for needs_fixed_stride_order silent incorrectness
This is a low-risk short-term fix for #128084, for the purposes of 2.4.1. The actual fix for that issue is more risky and we'll target 2.5.

needs_fixed_stride_order is silently incorrect with args that are mutable because it creates clones of those args, writes into them, and doesn't update the original args.

This PR makes it so that needs_fixed_stride_order doesn't apply to inputs that are being mutated. This PR doesn't completely fix the problem, but it makes it less incorrect: most of the time the input already has the correct strides but inductor fails to recognize it, and in those cases writing directly to the input is fine.

Test Plan:
- new test

ghstack-source-id: 8d19392
Pull Request resolved: #133452
1 parent cd565bc commit 1ba7adc
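To illustrate the failure mode described above, here is a toy sketch (not Inductor's actual code; `constrain_by_cloning` and `mutable_op` are made-up stand-ins): satisfying a stride constraint by cloning an argument that the op mutates means the write lands in the clone and the caller's buffer never sees it, which is why this fix leaves mutated inputs unconstrained.

```python
import torch

def constrain_by_cloning(arg):
    # Stand-in for a stride-order constraint: materialize a fresh copy.
    return arg.contiguous().clone()

def mutable_op(ret, src):
    ret.copy_(src)  # writes in place into whatever tensor it is handed

buf = torch.full([3], 7.0)
src = torch.arange(3, dtype=torch.float)

mutable_op(constrain_by_cloning(buf), src)   # buggy path: write lands in the clone
print(buf)                                   # still tensor([7., 7., 7.])

mutable_op(buf, src)                         # fixed path: mutated arg passed through
print(buf)                                   # tensor([0., 1., 2.])
```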

File tree

4 files changed: +70 -4 lines changed

test/inductor/test_torchinductor.py

Lines changed: 31 additions & 0 deletions
@@ -10571,6 +10571,37 @@ def fn(x):
                 # But because our custom op needs fixed layout, the assertions in the custom op will pass
                 self.common(fn, (inp,), check_lowp=False)

+    @config.patch(implicit_fallbacks=True)
+    def test_mutable_custom_op_fixed_layout(self):
+        with torch.library._scoped_library("mylib", "DEF") as lib:
+            lib.define(
+                "copy_(Tensor(a!) ret, Tensor tensors, int dim) -> ()",
+                tags=torch.Tag.needs_fixed_stride_order,
+            )
+
+            @torch.library.impl(lib, "copy_", "Meta")
+            def _(ret, tensors, dim):
+                return None
+
+            @torch.library.impl(lib, "copy_", "CPU")
+            def _(ret, tensors, dim):
+                ret.copy_(tensors)
+
+            def f(x):
+                full_default_3 = torch.full([3], 7.0, device="cpu")
+                chunk_cat_default_1 = torch.ops.mylib.copy_.default(
+                    full_default_3, x, 0
+                )
+                mul_out = torch.mul(full_default_3, full_default_3)
+                return mul_out
+
+            x = torch.arange(3, dtype=torch.float, device="cpu")
+            eager_out = f(x)
+
+            compiled_inductor_f = torch.compile(f, backend="inductor", fullgraph=True)
+            compiled_inductor_out = compiled_inductor_f(x)
+            self.assertEqual(compiled_inductor_out, eager_out)
+
     @requires_gpu()
     @config.patch(implicit_fallbacks=True)
     def test_custom_op_fixed_layout_channels_last(self):
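For concreteness, the values the new test compares, worked out in eager mode (the pre-fix compiled result below is inferred from the commit message, not reproduced here):

```python
import torch

x = torch.arange(3, dtype=torch.float)   # tensor([0., 1., 2.])
buf = torch.full([3], 7.0)               # tensor([7., 7., 7.])
buf.copy_(x)                             # mylib.copy_ writes x into buf in place
print(torch.mul(buf, buf))               # eager_out: tensor([0., 1., 4.])

# Before this fix, the compiled graph could clone `buf` to satisfy the stride
# constraint; the write would then land in the clone and the multiply would
# still see the 7s, yielding tensor([49., 49., 49.]) -- the silent incorrectness.
```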

torch/_inductor/graph.py

Lines changed: 4 additions & 3 deletions
@@ -991,12 +991,13 @@ def get_custom_op_layout_constraints(
                 # We have to OrderedSet the current args because call_function will immediately
                 # evaluate this lowering after creating the fallback, without evaluating
                 # the layout constraint
-                args, kwargs = constrain_to_fx_strides(
-                    self.current_node, *args, **kwargs
+                constrain_fn = functools.partial(
+                    constrain_to_fx_strides, ignore_mutated_args_FIXME=True
                 )
+                args, kwargs = constrain_fn(self.current_node, *args, **kwargs)
                 # Also register the layout constraint so when the fallback
                 # is used again, we can constrain the args to the same layout
-                layout_constraint = constrain_to_fx_strides
+                layout_constraint = constrain_fn
             return layout_constraint, args, kwargs

         if target not in lowerings:
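The point of the `functools.partial` here is that the flag travels with the function object: both the immediate call and the `layout_constraint` registered for later fallback uses take the same skip-mutated-args path. A minimal standalone illustration (a toy stand-in, not the real lowering):

```python
import functools

# Toy stand-in: only the keyword-bound variant skips mutated args, and
# partial() bakes that choice into constrain_fn once and for all.
def constrain_to_fx_strides(fx_node, *args, ignore_mutated_args_FIXME=False, **kwargs):
    return "skip mutated args" if ignore_mutated_args_FIXME else "constrain everything"

constrain_fn = functools.partial(constrain_to_fx_strides, ignore_mutated_args_FIXME=True)

print(constrain_to_fx_strides("node"))  # constrain everything
print(constrain_fn("node"))             # skip mutated args, also when reused later
```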

torch/_inductor/lowering.py

Lines changed: 23 additions & 1 deletion
@@ -2067,13 +2067,35 @@ def require_channels_last(_, *args, **kwargs):
     return args, kwargs


-def constrain_to_fx_strides(fx_node, *args, **kwargs):
+def constrain_to_fx_strides(fx_node, *args, ignore_mutated_args_FIXME=False, **kwargs):
     def apply_constraint(arg, fx_arg):
         if isinstance(arg, ir.IRNode):
             stride_order = ir.get_stride_order(fx_arg.meta["val"].stride())
             return ir.ExternKernel.require_stride_order(arg, stride_order)
         return arg

+    if ignore_mutated_args_FIXME:
+        assert isinstance(fx_node.target, torch._ops.OpOverload)
+        schema = fx_node.target._schema
+
+        new_args = []
+        new_kwargs = {}
+        schema_args, schema_kwargs = torch._library.utils.schema_args_kwargs(schema)
+        for arg, fx_arg, schema_arg in zip(args, fx_node.args, schema_args):
+            if schema_arg.alias_info is not None and schema_arg.alias_info.is_write:
+                new_args.append(arg)
+            else:
+                new_args.append(apply_constraint(arg, fx_arg))
+        for key in kwargs:
+            arg = kwargs[key]
+            fx_arg = fx_node.kwargs[key]
+            schema_arg = schema_kwargs[key]
+            if schema_arg.alias_info is not None and schema_arg.alias_info.is_write:
+                new_kwargs[key] = arg
+            else:
+                new_kwargs[key] = apply_constraint(arg, fx_arg)
+        return tuple(new_args), new_kwargs
+
     args = tuple(
         apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)
     )
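The new branch keys on the schema's alias annotations. A quick way to see which arguments it leaves unconstrained, using the schema string from the new test (assumes `torch._C.parse_schema` as in recent PyTorch releases):

```python
import torch

schema = torch._C.parse_schema(
    "mylib::copy_(Tensor(a!) ret, Tensor tensors, int dim) -> ()"
)
for arg in schema.arguments:
    mutated = arg.alias_info is not None and arg.alias_info.is_write
    # `ret` carries the (a!) write annotation and is passed through untouched;
    # `tensors` and `dim` still get the fixed-stride constraint applied.
    print(arg.name, "mutated" if mutated else "read-only")
```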

torch/_library/utils.py

Lines changed: 12 additions & 0 deletions
@@ -177,6 +177,18 @@ def fill_defaults(schema, args, kwargs):
     return tuple(new_args), new_kwargs


+def schema_args_kwargs(schema):
+    args = []
+    kwargs = {}
+    for i in range(len(schema.arguments)):
+        info = schema.arguments[i]
+        if info.kwarg_only:
+            kwargs[info.name] = info
+            continue
+        args.append(info)
+    return tuple(args), kwargs
+
+
 def zip_schema(
     schema: _C.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
 ) -> Iterable[Tuple[_C.Argument, Any]]:
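A sketch of what the new helper returns: positional arguments as a tuple of `Argument` objects and kwarg-only arguments as a name-to-`Argument` dict. The `mylib::scale` signature is hypothetical, and the import requires this patch to be applied:

```python
import torch
from torch._library.utils import schema_args_kwargs  # added by this commit

schema = torch._C.parse_schema(
    "mylib::scale(Tensor(a!) out, Tensor src, *, float alpha=1.0) -> ()"
)
args, kwargs = schema_args_kwargs(schema)
print([a.name for a in args])   # positional: ['out', 'src']
print(list(kwargs.keys()))      # kwarg-only: ['alpha']
```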
