8000 fix specialized symbols after runtime assertions added · pytorch/pytorch@4c53723 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4c53723

Browse files
committed
fix specialized symbols after runtime assertions added
ghstack-source-id: b401169 Pull Request resolved: #153661
1 parent 004dad4 commit 4c53723

File tree

3 files changed

+77
-6
lines changed

3 files changed

+77
-6
lines changed

test/test_dynamic_shapes.py

+41
Original file line numberDiff line numberDiff line change
@@ -3005,6 +3005,15 @@ def test_remove_symbols_without_guarding(self):
30053005
self.assertEqual(f"{x_clean.stride()}", "(8, 1)")
30063006
self.assertEqual(f"{x_clean.shape}", "torch.Size([5, 8])")
30073007

3008+
3009+
def custom_pass(graph: torch.fx.Graph) -> torch.fx.Graph:
3010+
for node in graph.nodes:
3011+
if node.name == "arg3_1":
3012+
assert node.meta["val"].size()[0] == 2
3013+
return graph
3014+
3015+
3016+
class TestUnbacked(TestCase):
30083017
@torch._dynamo.config.patch("capture_scalar_outputs", True)
30093018
def test_deferred_neq_assert(self):
30103019
@torch.compile(fullgraph=True)
@@ -3017,6 +3026,38 @@ def func(a):
30173026
with self.assertRaises(RuntimeError):
30183027
func(torch.tensor([5]))
30193028

3029+
# Test a situation where we generate a runtime assert i.e: u1==s1, then we spcialize s1
3030+
# later on to a constant.
3031+
@torch._dynamo.config.patch("capture_s 10000 calar_outputs", True)
3032+
def test_post_specialize_runtime_assert1(self):
3033+
@torch.compile(dynamic=True)
3034+
def func(x, y):
3035+
u0 = y.item()
3036+
s0 = x.size()[0]
3037+
s1 = x.size()[1]
3038+
torch._check(u0 + s0 + s1 == 102)
3039+
assert s0 == 2
3040+
return x * 10
3041+
3042+
func(torch.rand(2, 50), torch.tensor([50]))
3043+
with self.assertRaises(RuntimeError):
3044+
func(torch.rand(2, 50), torch.tensor([51]))
3045+
3046+
@torch._dynamo.config.patch("capture_scalar_outputs", True)
3047+
@torch._inductor.config.patch(post_grad_custom_pre_pass=custom_pass)
3048+
def test_post_specialize_runtime_assert2(self):
3049+
@torch.compile(dynamic=True)
3050+
def func(x, y):
3051+
u0 = y.item()
3052+
s0 = x.size()[0]
3053+
s1 = x.size()[1]
3054+
torch._check(u0 + s0 + s1 == 102)
3055+
return x * 10
3056+
3057+
func(torch.rand(2, 50), torch.tensor([50]))
3058+
with self.assertRaises(RuntimeError):
3059+
func(torch.rand(2, 50), torch.tensor([51]))
3060+
30203061
@torch._dynamo.config.patch("capture_scalar_outputs", True)
30213062
def test_deferred_sym_or_assert(self):
30223063
@torch.compile(fullgraph=True)

torch/fx/experimental/symbolic_shapes.py

+32-5
Original file line numberDiff line numberDiff line change
@@ -2320,11 +2320,18 @@ def wrapper(self: ShapeEnv, *args: Any, **kwargs: Any) -> _T: # type: ignore[mi
23202320
# and is exclusively used for things that MUST be true (unlike guards,
23212321
# which can evaluate False, in which case you just choose not to use
23222322
# a particular specialization)
2323-
@dataclass(frozen=True)
2323+
@dataclass()
23242324
class RuntimeAssert:
2325-
expr: SympyBoolean
2325+
_expr: SympyBoolean
23262326
msg: str = field(repr=False)
23272327
stack: CapturedTraceback = field(repr=False)
2328+
shape_env: ShapeEnv
2329+
2330+
@property
2331+
def expr(self) -> SympyBoolean:
2332+
# Whenever we access expr we want to replace specialized backed symbols with their corresponding
2333+
# specialized values.
2334+
return self.shape_env.replace(self._expr, only_specialized_backed=True)
23282335

23292336

23302337
# Used for printing SymExprs in compile_fx
@@ -4483,6 +4490,12 @@ def is_unbacked_symint(self, symbol: sympy.Symbol) -> bool:
44834490
"""Check if a sympy symbol matches the naming convention for unbacked symbols"""
44844491
return symbol_is_type(symbol, SymT.UNBACKED_INT)
44854492

4493+
def is_unbacked(self, symbol: sympy.Symbol) -> bool:
4494+
"""Check if a sympy symbol matches the naming convention for unbacked symbols"""
4495+
return symbol_is_type(symbol, SymT.UNBACKED_INT) or symbol_is_type(
4496+
symbol, SymT.UNBACKED_FLOAT
4497+
)
4498+
44864499
@record_shapeenv_event()
44874500
def create_unbacked_symbool(self) -> SymBool:
44884501
"""Create a symbolic boolean without a hint value"""
@@ -5941,11 +5954,22 @@ def resimplify_floor_div(axioms: dict[sympy.Expr, sympy.Expr]) -> None:
59415954
return r
59425955

59435956
@_lru_cache
5944-
def replace(self, expr: _SympyT) -> _SympyT:
5945-
"""Apply symbol replacements to any symbols in the given expression"""
5957+
def replace(self, expr: _SympyT, only_specialized_backed: bool = False) -> _SympyT:
5958+
"""
5959+
Apply symbol replacements to any symbols in the given expression
5960+
If only_specialized_backed is set, only repalce backed symbols that
5961+
got specialized i.e: s0:10.
5962+
"""
59465963
replacements = {}
59475964
for s in expr.free_symbols:
5965+
if only_specialized_backed and self.is_unbacked(s):
5966+
continue
5967+
59485968
r = self._find(s)
5969+
5970+
if only_specialized_backed and not (isinstance(r, (int, float))):
5971+
continue
5972+
59495973
# Micro-optimization: only do replacements if r and s are different
59505974
# Otherwise, xreplace is not a no-op and will trigger expensive
59515975
# assumption queries if expr has a relational node.
@@ -6380,6 +6404,9 @@ def _set_replacement(self, a: sympy.Symbol, tgt: sympy.Expr, msg: str) -> None:
63806404
},
63816405
)
63826406

6407+
# when specializing to a constant, runtime assertions that uses a, need to
6408+
# be updated to remove a from them, since a will no longer be a graph input.
6409+
63836410
for source in self.var_to_sources.get(a, []):
63846411
if user_tb:
63856412
self.specialization_stacks[source] = user_tb
@@ -7293,7 +7320,7 @@ def defer_runtime_assert(
72937320
orig_expr = expr
72947321
expr = canonicalize_bool_expr(expr)
72957322
stack = CapturedTraceback.extract(skip=1)
7296-
ra = RuntimeAssert(expr, msg, stack)
7323+
ra = RuntimeAssert(expr, msg, stack, self)
72977324
# TODO: Do this in a way that is less janky than int(s.name[1:])
72987325
cands = sorted(
72997326
(s for s in expr.free_symbols if symbol_is_type(s, SymT.UNBACKED_INT)),

torch/fx/passes/runtime_assert.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,10 @@ def add_runtime_asserts(ras):
257257
# Convert the sympy expression into a sequence of FX
258258
# nodes
259259
with _set_node_metadata_hook(gm, _node_metadata_hook):
260-
res = _sympy_interp(expr_to_proxy, ra.expr).node
260+
try:
261+
res = _sympy_interp(expr_to_proxy, ra.expr).node
262+
except:
< 5178 /code>263+
raise
261264
graph.call_function(
262265
torch.ops.aten._assert_scalar.default,
263266
# TODO: use ra.msg here, but it's pretty

0 commit comments

Comments
 (0)
0