8000 Fix: Replacements can cause runtime assertions to disappear and can cause invalid inductor code. by laithsakka · Pull Request #153661 · pytorch/pytorch · GitHub

Closed · wants to merge 18 commits

Changes from all commits (18 commits):
88b715b - fix specialized symbols after runtime assertions added (laithsakka, May 15, 2025)
d8dc4cf - Update on "Fix: specializing symbols after runtime assertions added c…" (laithsakka, May 16, 2025)
cebce0b - Update on "Fix: specializing symbols after runtime assertions added c…" (laithsakka, May 16, 2025)
4721f4b - Update on "Fix: specializing symbols after runtime assertions added c…" (laithsakka, May 17, 2025)
55c5555 - Update on "Fix: specializing symbols after runtime assertions added c…" (laithsakka, May 17, 2025)
548b07b - Update on "Fix: specializing symbols after runtime assertions added c…" (laithsakka, May 17, 2025)
65e98aa - Update on "Fix: specializing symbols after runtime assertions added c…" (laithsakka, May 17, 2025)
2511a89 - Update on "Fix: Ban replacements for unbacked" (laithsakka, May 19, 2025)
d2aa4da - Update on "Fix: Ban replacements for unbacked" (laithsakka, May 19, 2025)
66f3652 - Update on "Fix: Replacements can cause runtime assertions to disappea…" (laithsakka, May 21, 2025)
0a73ecb - Update on "Fix: Replacements can cause runtime assertions to disappea…" (laithsakka, May 22, 2025)
9109a4b - Update on "Fix: Replacements can cause runtime assertions to disappea…" (laithsakka, May 22, 2025)
19c093e - Update on "Fix: Replacements can cause runtime assertions to disappea…" (laithsakka, May 22, 2025)
d496e1e - Update on "Fix: Replacements can cause runtime assertions to disappea…" (laithsakka, May 22, 2025)
100180e - Update on "Fix: Replacements can cause runtime assertions to disappea…" (laithsakka, May 22, 2025)
804c7a0 - Update on "Fix: Replacements can cause runtime assertions to disappea…" (laithsakka, May 27, 2025)
fca5647 - Update on "Fix: Replacements can cause runtime assertions to disappea…" (laithsakka, May 27, 2025)
f38e11e - Update on "Fix: Replacements can cause runtime assertions to disappea…" (laithsakka, May 28, 2025)
21 changes: 15 additions & 6 deletions test/dynamo/test_backward_higher_order_ops.py
@@ -130,14 +130,17 @@ def fn(x, y):
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_inputs_ : list, s69: "Sym(s21)"):
def forward(self, L_inputs_ : list, s69: "Sym(s21)", L_sizes_0_: "f32[0, s21]"):
l_inputs_ = L_inputs_
l_sizes_0_ = L_sizes_0_

getitem: "f32[s21]" = l_inputs_[0]
getitem_1: "f32[s21]" = l_inputs_[1]
getitem_2: "f32[s21]" = l_inputs_[2]; l_inputs_ = None

validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [s69], False)]); getitem = s69 = None
size: "Sym(s21)" = l_sizes_0_.size(1); l_sizes_0_ = None

validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False)]); getitem = size = None
getitem_9: "f32[s21]" = validate_outputs[0]; validate_outputs = None

call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
@@ -160,14 +163,17 @@ def forward(self, L_inputs_ : list, s69: "Sym(s21)"):
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_inputs_ : list, s69: "Sym(s21)"):
def forward(self, L_inputs_ : list, s69: "Sym(s21)", L_sizes_0_: "f32[0, s21]"):
l_inputs_ = L_inputs_
l_sizes_0_ = L_sizes_0_

getitem: "f32[s21]" = l_inputs_[0]
getitem_1: "f32[s21]" = l_inputs_[1]
getitem_2: "f32[s21]" = l_inputs_[2]; l_inputs_ = None

validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [s69], False)]); getitem = s69 = None
size: "Sym(s21)" = l_sizes_0_.size(1); l_sizes_0_ = None

validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False)]); getitem = size = None
getitem_9: "f32[s21]" = validate_outputs[0]; validate_outputs = None

call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
@@ -242,15 +248,18 @@ def fn(x, y):
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_inputs_ : list, s69: "Sym(s21)", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s45)"):
def forward(self, L_inputs_ : list, s69: "Sym(s21)", L_sizes_0_: "f32[0, s21]", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s45)"):
l_inputs_ = L_inputs_
l_sizes_0_ = L_sizes_0_
l_hooks_1_keywords_fn_keywords_obj_counter = L_hooks_1_keywords_fn_keywords_obj_counter

getitem: "f32[s21]" = l_inputs_[0]
getitem_1: "f32[s21]" = l_inputs_[1]
getitem_2: "f32[s21]" = l_inputs_[2]; l_inputs_ = None

validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [s69], False)]); getitem = s69 = None
size: "Sym(s21)" = l_sizes_0_.size(1); l_sizes_0_ = None

validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False)]); getitem = size = None
getitem_9: "f32[s21]" = validate_outputs[0]; validate_outputs = None

call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
3 changes: 3 additions & 0 deletions test/export/test_experimental.py
@@ -406,7 +406,10 @@ def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
linear_weight = self.linear.weight
linear_bias = self.linear.bias
sym_size_int_2 = torch.ops.aten.sym_size.int(x, 1)
linear = torch.ops.aten.linear.default(x, linear_weight, linear_bias); x = linear_weight = linear_bias = None
eq = sym_size_int_2 == 4; sym_size_int_2 = None
_assert_scalar_default = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s16, 4) on node 'eq'"); eq = _assert_scalar_default = None
return pytree.tree_unflatten((linear,), self._out_spec)""",
)

1 change: 0 additions & 1 deletion test/export/test_export.py
@@ -4478,7 +4478,6 @@ def forward(self, x, y, z):
ep = export(Foo(), inps, dynamic_shapes=dynamic_shapes)
# values should have no unbacked symbols, bindings should be empty
for node in ep.graph.nodes:
symbols = []
val = node.meta.get("val")
bindings = node.meta.get("unbacked_bindings")
self.assertTrue(
2 changes: 1 addition & 1 deletion test/export/test_unflatten.py
@@ -879,7 +879,7 @@ def forward(self, x, y):
fn_count_sym_size = lambda graph: [node.target for node in graph.nodes].count(
torch.ops.aten.sym_size.int
)
self.assertEqual(fn_count_sym_size(unflat.graph), 1)
self.assertEqual(fn_count_sym_size(unflat.graph), 3)
self.assertEqual(fn_count_sym_size(unflat.m1.graph), 1)
self.assertEqual(fn_count_sym_size(unflat.m2.graph), 0)

68 changes: 59 additions & 9 deletions test/test_dynamic_shapes.py
@@ -3041,9 +3041,19 @@ def test_remove_symbols_without_guarding(self):
self.assertEqual(f"{x_clean.stride()}", "(8, 1)")
self.assertEqual(f"{x_clean.shape}", "torch.Size([5, 8])")


def custom_pass(graph: torch.fx.Graph) -> torch.fx.Graph:
for node in graph.nodes:
if node.name == "arg3_1":
assert node.meta["val"].size()[0] == 2
return graph


class TestUnbacked(TestCase):
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_deferred_neq_assert(self):
@torch.compile(fullgraph=True)
@parametrize("backend", ["inductor", "eager"])
def test_deferred_neq_assert(self, backend):
@torch.compile(fullgraph=True, backend=backend)
def func(a):
torch._check(a.item() != 5)
return a.item() * 10
@@ -3053,9 +3063,44 @@ def func(a):
with self.assertRaises(RuntimeError):
func(torch.tensor([5]))

# Test a situation where we generate a runtime assert i.e: u1==s1, then we specialize s1
# later on to a constant.
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@parametrize("backend", ["inductor", "eager"])
def test_post_specialize_runtime_assert1(self, backend):
@torch.compile(dynamic=True, backend=backend)
def func(x, y):
u0 = y.item()
s0 = x.size()[0]
s1 = x.size()[1]
torch._check(u0 + s0 + s1 == 102)
assert s0 == 2
return x * 10

func(torch.rand(2, 50), torch.tensor([50]))
with self.assertRaises(RuntimeError):
func(torch.rand(2, 50), torch.tensor([51]))

@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._inductor.config.patch(post_grad_custom_pre_pass=custom_pass)
@parametrize("backend", ["inductor", "eager"])
def test_post_specialize_runtime_assert2(self, backend):
@torch.compile(dynamic=True, backend=backend)
def func(x, y):
u0 = y.item()
s0 = x.size()[0]
s1 = x.size()[1]
torch._check(u0 + s0 + s1 == 102)
return x * 10

func(torch.rand(2, 50), torch.tensor([50]))
with self.assertRaises(RuntimeError):
func(torch.rand(2, 50), torch.tensor([51]))

@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_deferred_sym_or_assert(self):
@torch.compile(fullgraph=True)
@parametrize("backend", ["inductor", "eager"])
def test_deferred_sym_or_assert(self, backend):
@torch.compile(fullgraph=True, backend=backend)
def func(a, b):
torch._check(operator.or_(a.item() == 5, b.item() == 5))
return a.item() * 10
@@ -3074,8 +3119,9 @@ def test_has_free_symbols(self):
self.assertTrue(has_free_symbols(sympy.sympify("a+b")))

@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_deferred_sym_eq_assert(self):
@torch.compile(fullgraph=True)
@parametrize("backend", ["inductor", "eager"])
def test_deferred_sym_eq_assert(self, backend):
@torch.compile(fullgraph=True, backend=backend)
def func(a, b):
torch._check(b.item() == 5)
return a * 10
@@ -3084,10 +3130,11 @@ def func(a, b):
with self.assertRaises(RuntimeError):
func(torch.tensor([100]), torch.tensor([1]))

@skipIfTorchDynamo("mark_unbacked is not traceable")
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_deferred_with_unbacked_input(self):
@torch.compile(fullgraph=True, dynamic=True, backend="inductor")
@parametrize("backend", ["inductor", "eager"])
@skipIfTorchDynamo("mark_unbacked is not traceable")
def test_deferred_with_unbacked_input(self, backend):
@torch.compile(fullgraph=True, dynamic=True, backend=backend)
def func(a, b):
torch._check(a.size()[0] == b.size()[0])
return a * 10
@@ -3338,5 +3385,8 @@ def func(x, y):
compiled_func(x, torch.tensor([5, 20]))


instantiate_parametrized_tests(TestUnbacked)


if __name__ == "__main__":
run_tests()
14 changes: 5 additions & 9 deletions torch/_dynamo/output_graph.py
@@ -1726,10 +1726,11 @@ def remove_unused_get_attr_nodes(self) -> None:
self.remove_node(node)

def remove_unused_graphargs(self) -> None:
# NB: It's always OK to drop GraphArg for symbols that ended up being
# specialized. You don't even have to make a guard for it, because
# ShapeEnv produce_guards operates on tracked_fakes, which never gets
# pruned. That being said, you'll get marginally better generated
# NB: It's OK to drop GraphArg for symbols that ended up being
# specialized iff they are not used in runtime assertions. You don't
# even have to make a guard for it, because ShapeEnv produce_guards
# operates on tracked_fakes, which never gets pruned.
# That being said, you'll get marginally better generated
# guard code if you promote the guard into a Dynamo guard (since that
# allows for the guard to be done using C++ guards.) If we get
# ShapeEnv guards to go into C++ guards, this will stop being a thing
@@ -1864,11 +1865,6 @@ def update_used_symbols(used_symbols, fake: Union[torch.SymInt, torch.Tensor]):
)
update_used_symbols(used_symbols, fake)

# Preserve all symbols that appears in original expressions of a deferred_runtime_asserts.
for assertion_list in self.shape_env.deferred_runtime_asserts.values():
for assertion in assertion_list:
used_symbols |= free_symbols(assertion.expr)

# After removing unused graphargs, prune unused binds_symbol
for node in recheck_placeholders:
symbol = placeholder_binds_symbol(node)
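For illustration, a minimal repro sketch of the scenario the updated comment describes, mirroring the new test_post_specialize_runtime_assert1 added above (this sketch is not part of the diff; the backend choice and shapes are illustrative, and the parametrized test also covers inductor). A deferred runtime assert mentions s0 and s1, and s0 is later specialized to a constant, so the graph input backing that symbol cannot simply be dropped without losing the assertion.

import torch

# Minimal sketch, assuming default torch.compile semantics; the values mirror
# the new test: 50 (u0) + 2 (s0) + 50 (s1) == 102 satisfies the deferred check.
with torch._dynamo.config.patch(capture_scalar_outputs=True):

    @torch.compile(dynamic=True, backend="eager")
    def func(x, y):
        u0 = y.item()                      # unbacked symbol
        s0 = x.size()[0]                   # backed symbol, specialized below
        s1 = x.size()[1]
        torch._check(u0 + s0 + s1 == 102)  # deferred runtime assert on u0, s0, s1
        assert s0 == 2                     # specializes s0 to 2 after the check
        return x * 10

    func(torch.rand(2, 50), torch.tensor([50]))  # passes: 50 + 2 + 50 == 102
    func(torch.rand(2, 50), torch.tensor([51]))  # expected to raise RuntimeError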
3 changes: 2 additions & 1 deletion torch/_inductor/graph.py
@@ -34,6 +34,7 @@
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.sym_node import magic_methods, method_to_operator
from torch.fx.experimental.symbolic_shapes import (
_get_placeholder_expr,
free_unbacked_symbols,
has_free_symbols,
resolve_unbacked_bindings,
@@ -1066,7 +1067,7 @@ def placeholder(
example = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
target = self.qualify_name(target)
if isinstance(example, SymTypes):
expr = example.node.expr
expr = _get_placeholder_expr(example.node)
self.graph_inputs[target] = expr
self.graph_input_names.append(target)
return expr
17 changes: 16 additions & 1 deletion torch/fx/experimental/symbolic_shapes.py
@@ -5955,10 +5955,13 @@ def resimplify_floor_div(axioms: dict[sympy.Expr, sympy.Expr]) -> None:

@_lru_cache
def replace(self, expr: _SympyT) -> _SympyT:
"""Apply symbol replacements to any symbols in the given expression"""
"""
Apply symbol replacements to any symbols in the given expression.
"""
replacements = {}
for s in expr.free_symbols:
r = self._find(s)

# Micro-optimization: only do replacements if r and s are different
# Otherwise, xreplace is not a no-op and will trigger expensive
# assumption queries if expr has a relational node.
@@ -7645,3 +7648,15 @@ def _remove_effect_token_unbacked_bindings(
yield
finally:
node.meta["unbacked_bindings"] = old_bindings


# This helper function is used in passes that insert runtime assertions in the graph.
# When accessing expressions representing input placeholders, we do not apply replacements
# since those inputs should be seen by assertions that use them to be inserted. The only replacement
# that we apply is unbacked renaming.
def _get_placeholder_expr(sym_node: SymNode) -> sympy.Expr:
shape_env = sym_node.shape_env
result = sym_node._expr
if result in shape_env.unbacked_renamings:
return shape_env.unbacked_renamings[result]
return result
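For context, a minimal usage sketch of the new helper (the wrapper function below and its argument name are hypothetical, not part of the diff): given a SymInt that backs a graph placeholder, the helper returns the original expression with only unbacked renamings applied, so a runtime assertion that references the original symbol can still be emitted against that input.

from torch.fx.experimental.symbolic_shapes import _get_placeholder_expr

def placeholder_input_expr(placeholder_symint):
    # placeholder_symint is assumed to be a torch.SymInt captured as a graph
    # input. Backed replacements (e.g. s0 -> 2) are deliberately not applied;
    # only unbacked renamings (e.g. u0 -> u1) are.
    return _get_placeholder_expr(placeholder_symint.node)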
6 changes: 5 additions & 1 deletion torch/fx/passes/runtime_assert.py
@@ -95,6 +95,7 @@ def insert_deferred_runtime_asserts(

from torch._export.passes._node_metadata_hook import _set_node_metadata_hook
from torch.fx.experimental.symbolic_shapes import (
_get_placeholder_expr,
_has_uninterpretable_sympy_function,
CallMethodKey,
cast_symbool_to_symint_guardless,
@@ -258,6 +259,7 @@ def add_runtime_asserts(ras):
# nodes
with _set_node_metadata_hook(gm, _node_metadata_hook):
res = _sympy_interp(expr_to_proxy, ra.expr).node

graph.call_function(
torch.ops.aten._assert_scalar.default,
# TODO: use ra.msg here, but it's pretty
@@ -290,7 +292,9 @@ def match_symbol(symint, cb):
if (
isinstance(symint, torch.SymInt)
and isinstance(symint.node, SymNode)
and isinstance(s := symint.node.expr, sympy.Symbol)
and isinstance(
s := _get_placeholder_expr(symint.node), sympy.Symbol
)
and s not in expr_to_proxy
):
with _set_node_metadata_hook(gm, _node_metadata_hook):