fix specialized symbols after runtime assertions added · pytorch/pytorch@8d0ed30 · GitHub

Commit 8d0ed30

fix specialized symbols after runtime assertions added
ghstack-source-id: 2b24cb6
Pull Request resolved: #153661
1 parent 004dad4 commit 8d0ed30
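In short: Dynamo can record a deferred runtime assertion over symbolic sizes (e.g. u0 + s0 + s1 == 102) and later specialize one of those symbols (e.g. s0 == 2) to a constant. Specialization rewrites the symbol's expression to a plain integer, so the assertion-insertion pass could no longer match the symbol and some runtime asserts went missing from the graph. A minimal repro, distilled from the new tests in this commit:

import torch

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(dynamic=True)
def func(x, y):
    u0 = y.item()                      # unbacked symbol u0
    s0, s1 = x.size()[0], x.size()[1]  # dynamic sizes s0, s1
    torch._check(u0 + s0 + s1 == 102)  # recorded as a deferred runtime assert
    assert s0 == 2                     # later specializes s0 to a constant
    return x * 10

func(torch.rand(2, 50), torch.tensor([50]))  # ok: 50 + 2 + 50 == 102
func(torch.rand(2, 50), torch.tensor([51]))  # must raise RuntimeError, not pass silently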

5 files changed (+65, -12 lines)


test/test_dynamic_shapes.py

Lines changed: 41 additions & 0 deletions
@@ -3005,6 +3005,15 @@ def test_remove_symbols_without_guarding(self):
         self.assertEqual(f"{x_clean.stride()}", "(8, 1)")
         self.assertEqual(f"{x_clean.shape}", "torch.Size([5, 8])")
 
+
+
+def custom_pass(graph: torch.fx.Graph) -> torch.fx.Graph:
+    for node in graph.nodes:
+        if node.name == "arg3_1":
+            assert node.meta["val"].size()[0] == 2
+    return graph
+
+
 class TestUnbacked(TestCase):
     @torch._dynamo.config.patch("capture_scalar_outputs", True)
     def test_deferred_neq_assert(self):
         @torch.compile(fullgraph=True)
@@ -3017,6 +3026,38 @@ def func(a):
         with self.assertRaises(RuntimeError):
             func(torch.tensor([5]))
 
+    # Test a situation where we generate a runtime assert, i.e. u1 == s1, then we
+    # specialize s1 later on to a constant.
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    def test_post_specialize_runtime_assert1(self):
+        @torch.compile(dynamic=True)
+        def func(x, y):
+            u0 = y.item()
+            s0 = x.size()[0]
+            s1 = x.size()[1]
+            torch._check(u0 + s0 + s1 == 102)
+            assert s0 == 2
+            return x * 10
+
+        func(torch.rand(2, 50), torch.tensor([50]))
+        with self.assertRaises(RuntimeError):
+            func(torch.rand(2, 50), torch.tensor([51]))
+
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    @torch._inductor.config.patch(post_grad_custom_pre_pass=custom_pass)
+    def test_post_specialize_runtime_assert2(self):
+        @torch.compile(dynamic=True)
+        def func(x, y):
+            u0 = y.item()
+            s0 = x.size()[0]
+            s1 = x.size()[1]
+            torch._check(u0 + s0 + s1 == 102)
+            return x * 10
+
+        func(torch.rand(2, 50), torch.tensor([50]))
+        with self.assertRaises(RuntimeError):
+            func(torch.rand(2, 50), torch.tensor([51]))
+
     @torch._dynamo.config.patch("capture_scalar_outputs", True)
     def test_deferred_sym_or_assert(self):
         @torch.compile(fullgraph=True)
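Both new tests follow the same recipe: compile once with inputs that satisfy the check, then call again with an input that violates it and expect the generated runtime assert to raise RuntimeError. The second test additionally uses Inductor's post_grad_custom_pre_pass hook so custom_pass can verify that the specialized size 2 is still visible on the placeholder's meta["val"]. Assuming a PyTorch development checkout with pytest available, they can be run with:

pytest test/test_dynamic_shapes.py -k test_post_specialize_runtime_assert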

torch/_dynamo/output_graph.py

Lines changed: 12 additions & 9 deletions
@@ -1675,17 +1675,25 @@ def remove_unused_get_attr_nodes(self) -> None:
                 self.remove_node(node)
 
     def remove_unused_graphargs(self) -> None:
-        # NB: It's always OK to drop GraphArg for symbols that ended up being
-        # specialized. You don't even have to make a guard for it, because
-        # ShapeEnv produce_guards operates on tracked_fakes, which never gets
-        # pruned. That being said, you'll get marginally better generated
+        # NB: It's OK to drop GraphArg for symbols that ended up being
+        # specialized iff they are not used in runtime assertions. You don't
+        # even have to make a guard for it, because ShapeEnv produce_guards
+        # operates on tracked_fakes, which never gets pruned.
+        # That being said, you'll get marginally better generated
         # guard code if you promote the guard into a Dynamo guard (since that
         # allows for the guard to be done using C++ guards.) If we get
         # ShapeEnv guards to go into C++ guards, this will stop being a thing
         # though!
 
         assert self.should_exit
 
+        # # Preserve all symbols that appear in the original expressions of
+        # # deferred_runtime_asserts, as placeholders.
+        # ras_symbols: set[sympy.Symbol] = set()
+        # for assertion_list in self.shape_env.deferred_runtime_asserts.values():
+        #     for assertion in assertion_list:
+        #         ras_symbols |= free_symbols(assertion.expr)
+
         # Miniature DCE pass, but only for obviously trivial operations
         def is_static_true(b_node: fx.node.Argument):
             if b_node is True:
@@ -1813,11 +1821,6 @@ def update_used_symbols(used_symbols, fake: Union[torch.SymInt, torch.Tensor]):
             )
             update_used_symbols(used_symbols, fake)
 
-        # Preserve all symbols that appears in original expressions of a deferred_runtime_asserts.
-        for assertion_list in self.shape_env.deferred_runtime_asserts.values():
-            for assertion in assertion_list:
-                used_symbols |= free_symbols(assertion.expr)
-
         # After removing unused graphargs, prune unused binds_symbol
         for node in recheck_placeholders:
             symbol = placeholder_binds_symbol(node)
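The deleted block (kept above in commented-out form as ras_symbols) gathered every free symbol mentioned by a deferred runtime assert so those symbols' placeholders survive pruning. A standalone sketch of that collection step in plain sympy; Assertion and the deferred_runtime_asserts mapping here are hypothetical stand-ins for the ShapeEnv structures named in the diff, and the free_symbols(e) helper is spelled e.free_symbols:

import sympy

u0, s0, s1 = sympy.symbols("u0 s0 s1", integer=True)


class Assertion:  # hypothetical stand-in for a deferred runtime assert record
    def __init__(self, expr: sympy.Expr) -> None:
        self.expr = expr


# Keyed by the symbol the assert was deferred on, as in ShapeEnv.
deferred_runtime_asserts = {u0: [Assertion(sympy.Eq(u0 + s0 + s1, 102))]}

ras_symbols: set[sympy.Symbol] = set()
for assertion_list in deferred_runtime_asserts.values():
    for assertion in assertion_list:
        ras_symbols |= assertion.expr.free_symbols

print(ras_symbols)  # {u0, s0, s1}: all three must keep their graph inputs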

torch/_inductor/graph.py

Lines changed: 3 additions & 1 deletion
@@ -1061,7 +1061,9 @@ def placeholder(
         example = super().placeholder(target, args, kwargs)  # type: ignore[arg-type]
         target = self.qualify_name(target)
         if isinstance(example, SymTypes):
-            expr = example.node.expr
+            # Use the original symbol and do not apply replacements to it if it was
+            # used in runtime assertions (TODO: update the PR to add that part).
+            expr = example.node._expr
             self.graph_inputs[target] = expr
             self.graph_input_names.append(target)
             return expr
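The one-character change from expr to _expr matters because, per the comment in runtime_assert.py below, specializations can override expr with specialized integers, while _expr preserves the original expression. A conceptual sympy sketch of the difference, not the actual SymNode internals:

import sympy

s0 = sympy.Symbol("s0", integer=True, positive=True)
replacements = {s0: sympy.Integer(2)}  # s0 was specialized to 2

_expr = s0                           # original expression stored on the node
expr = _expr.xreplace(replacements)  # the replaced view after specialization

assert isinstance(_expr, sympy.Symbol)     # still the symbol s0
assert not isinstance(expr, sympy.Symbol)  # now Integer(2)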

torch/fx/experimental/symbolic_shapes.py

Lines changed: 4 additions & 1 deletion
@@ -5942,10 +5942,13 @@ def resimplify_floor_div(axioms: dict[sympy.Expr, sympy.Expr]) -> None:
 
     @_lru_cache
     def replace(self, expr: _SympyT) -> _SympyT:
-        """Apply symbol replacements to any symbols in the given expression"""
+        """
+        Apply symbol replacements to any symbols in the given expression.
+        """
         replacements = {}
         for s in expr.free_symbols:
             r = self._find(s)
+
             # Micro-optimization: only do replacements if r and s are different
             # Otherwise, xreplace is not a no-op and will trigger expensive
             # assumption queries if expr has a relational node.
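For context on why replace() guards its work: the surrounding code only records a replacement when _find(s) returns something different from s, because calling xreplace with a non-trivial mapping can trigger expensive assumption queries on relational expressions. A self-contained sketch of the same pattern, with the union-find _find(s) approximated by a hypothetical dict lookup:

import sympy

u0, s0, s1 = sympy.symbols("u0 s0 s1", integer=True)
canonical = {s0: sympy.Integer(2)}  # stand-in for ShapeEnv replacement state

def replace(expr: sympy.Expr) -> sympy.Expr:
    replacements = {}
    for s in expr.free_symbols:
        r = canonical.get(s, s)  # hypothetical _find(s)
        # Only record a replacement when r differs from s, so unchanged
        # expressions skip xreplace entirely.
        if r is not s:
            replacements[s] = r
    return expr.xreplace(replacements) if replacements else expr

print(replace(sympy.Eq(u0 + s0 + s1, 102)))  # s0 replaced by 2 in the Eq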

torch/fx/passes/runtime_assert.py

Lines changed: 5 additions & 1 deletion
@@ -258,6 +258,7 @@ def add_runtime_asserts(ras):
             # nodes
             with _set_node_metadata_hook(gm, _node_metadata_hook):
                 res = _sympy_interp(expr_to_proxy, ra.expr).node
+
             graph.call_function(
                 torch.ops.aten._assert_scalar.default,
                 # TODO: use ra.msg here, but it's pretty
@@ -290,7 +291,10 @@ def match_symbol(symint, cb):
         if (
             isinstance(symint, torch.SymInt)
             and isinstance(symint.node, SymNode)
-            and isinstance(s := symint.node.expr, sympy.Symbol)
+            # Access the original expr (hence _expr, not expr), since specializations
+            # can override expr with specialized integers; using expr would result in
+            # missing some runtime assertions in the graph.
+            and isinstance(s := symint.node._expr, sympy.Symbol)
             and s not in expr_to_proxy
         ):
             with _set_node_metadata_hook(gm, _node_metadata_hook):
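Putting the two hunks together: once s0 specializes to 2, symint.node.expr for that input is sympy.Integer(2), the isinstance(..., sympy.Symbol) guard fails, no proxy is recorded in expr_to_proxy, and the _assert_scalar node for any assert mentioning s0 never makes it into the graph; matching on _expr keeps the placeholder discoverable. For reference, aten._assert_scalar is the op the pass emits; as I understand it, in eager mode it raises when its scalar argument is falsy (a quick hedged check, message argument assumed positional):

import torch

torch.ops.aten._assert_scalar(True, "u0 + s0 + s1 == 102")  # passes silently
try:
    torch.ops.aten._assert_scalar(False, "u0 + s0 + s1 == 102")
except RuntimeError as e:
    print(e)  # surfaces the assert message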
