@@ -2320,11 +2320,18 @@ def wrapper(self: ShapeEnv, *args: Any, **kwargs: Any) -> _T: # type: ignore[mi
2320
2320
# and is exclusively used for things that MUST be true (unlike guards,
2321
2321
# which can evaluate False, in which case you just choose not to use
2322
2322
# a particular specialization)
2323
- @dataclass (frozen = True )
2323
+ @dataclass ()
2324
2324
class RuntimeAssert :
2325
- expr : SympyBoolean
2325
+ _expr : SympyBoolean
2326
2326
msg : str = field (repr = False )
2327
2327
stack : CapturedTraceback = field (repr = False )
2328
+ shape_env : ShapeEnv
2329
+
2330
+ @property
2331
+ def expr (self ) -> SympyBoolean :
2332
+ # Whenever we access expr we want to replace specialized backed symbols with their corresponding
2333
+ # specialized values.
2334
+ return self .shape_env .replace (self ._expr , only_specialized_backed = True )
2328
2335
2329
2336
2330
2337
# Used for printing SymExprs in compile_fx
@@ -4483,6 +4490,12 @@ def is_unbacked_symint(self, symbol: sympy.Symbol) -> bool:
4483
4490
"""Check if a sympy symbol matches the naming convention for unbacked symbols"""
4484
4491
return symbol_is_type (symbol , SymT .UNBACKED_INT )
4485
4492
4493
+ def is_unbacked (self , symbol : sympy .Symbol ) -> bool :
4494
+ """Check if a sympy symbol matches the naming convention for unbacked symbols"""
4495
+ return symbol_is_type (symbol , SymT .UNBACKED_INT ) or symbol_is_type (
4496
+ symbol , SymT .UNBACKED_FLOAT
4497
+ )
4498
+
4486
4499
@record_shapeenv_event ()
4487
4500
def create_unbacked_symbool (self ) -> SymBool :
4488
4501
"""Create a symbolic boolean without a hint value"""
@@ -5941,11 +5954,22 @@ def resimplify_floor_div(axioms: dict[sympy.Expr, sympy.Expr]) -> None:
5941
5954
return r
5942
5955
5943
5956
@_lru_cache
5944
- def replace (self , expr : _SympyT ) -> _SympyT :
5945
- """Apply symbol replacements to any symbols in the given expression"""
5957
+ def replace (self , expr : _SympyT , only_specialized_backed : bool = False ) -> _SympyT :
5958
+ """
5959
+ Apply symbol replacements to any symbols in the given expression
5960
+ If only_specialized_backed is set, only repalce backed symbols that
5961
+ got specialized i.e: s0:10.
5962
+ """
5946
5963
replacements = {}
5947
5964
for s in expr .free_symbols :
5965
+ if only_specialized_backed and self .is_unbacked (s ):
5966
+ continue
5967
+
5948
5968
r = self ._find (s )
5969
+
5970
+ if only_specialized_backed and not (isinstance (r , (int , float ))):
5971
+ continue
5972
+
5949
5973
# Micro-optimization: only do replacements if r and s are different
5950
5974
# Otherwise, xreplace is not a no-op and will trigger expensive
5951
5975
# assumption queries if expr has a relational node.
@@ -6380,6 +6404,9 @@ def _set_replacement(self, a: sympy.Symbol, tgt: sympy.Expr, msg: str) -> None:
6380
6404
},
6381
6405
)
6382
6406
6407
+ # when specializing to a constant, runtime assertions that uses a, need to
6408
+ # be updated to remove a from them, since a will no longer be a graph input.
6409
+
6383
6410
for source in self .var_to_sources .get (a , []):
6384
6411
if user_tb :
6385
6412
self .specialization_stacks [source ] = user_tb
@@ -7293,7 +7320,7 @@ def defer_runtime_assert(
7293
7320
orig_expr = expr
7294
7321
expr = canonicalize_bool_expr (expr )
7295
7322
stack = CapturedTraceback .extract (skip = 1 )
7296
- ra = RuntimeAssert (expr , msg , stack )
7323
+ ra = RuntimeAssert (expr , msg , stack , self )
7297
7324
# TODO: Do this in a way that is less janky than int(s.name[1:])
7298
7325
cands = sorted (
7299
7326
(s for s in expr .free_symbols if symbol_is_type (s , SymT .UNBACKED_INT )),
0 commit comments