17
17
import abc
18
18
import atexit
19
19
import collections
20
+ import contextlib
20
21
import dis
21
22
import functools
22
23
import hashlib
@@ -308,8 +309,6 @@ def uninteresting_files() -> set[str]:
308
309
torch .fx .experimental .recording ,
309
310
torch .fx .experimental .sym_node ,
310
311
torch .fx .interpreter ,
311
- torch .fx .proxy ,
312
- torch .fx ._symbolic_trace ,
313
312
torch ,
314
313
torch ._compile ,
315
314
torch ._dynamo .eval_frame ,
@@ -1224,34 +1223,41 @@ def compute_unbacked_bindings(
1224
1223
# (1) It's an optimization/additional check I do not want to fail for not performing it.
1225
1224
# (2) I am willing to deviate from the normal semantics when I have unbacked for the
1226
1225
# benefit of not failing.
1226
+ def _guard_or (a : BoolLikeType , default : bool ) -> bool :
1227
+ if not isinstance (a , SymBool ):
1228
+ assert isinstance (a , bool )
1229
+ return a
1230
+
1231
+ # if backed_size_oblivious is True we treat backed as unbacked here.
1232
+ if torch .fx .experimental ._config .backed_size_oblivious :
1233
+ result = _static_eval_sym_bool (a )
1234
+ return result if result is not None else default
1235
+
1236
+ shape_env = getattr (a .node , "shape_env" , None )
1237
+
1238
+ # xla symnode path.
1239
+ if shape_env is None :
1240
+ return guard_bool (a )
1241
+
1242
+ with a .node .shape_env .dde_suppressed ():
1243
+ try :
1244
+ return guard_bool (a )
1245
+ except GuardOnDataDependentSymNode :
1246
+ return default
1247
+
1248
+
1227
1249
def guard_or_false (a : BoolLikeType ) -> bool :
1228
1250
"""
1229
1251
Try to guard a, if data dependent error encountered just return false.
1230
1252
"""
1231
- if torch .fx .experimental ._config .backed_size_oblivious :
1232
- return statically_known_true (a )
1233
- else :
1234
- try :
1235
- return bool (guard_bool (a ))
1236
- except GuardOnDataDependentSymNode :
1237
- return False
1253
+ return _guard_or (a , False )
1238
1254
1239
1255
1240
1256
def guard_or_true (a : BoolLikeType ) -> bool :
1241
1257
"""
1242
1258
Try to guard a, if data dependent error encountered just return true.
1243
1259
"""
1244
- if torch .fx .experimental ._config .backed_size_oblivious :
1245
- result = _static_eval (a )
1246
- if result is not None :
1247
- return result
1248
- else :
1249
- return True
1250
- else :
1251
- try :
1252
- return bool (guard_bool (a ))
1253
- except GuardOnDataDependentSymNode :
1254
- return True
1260
+ return _guard_or (a , True )
1255
1261
1256
1262
1257
1263
def definitely_true (a : BoolLikeType ) -> bool :
@@ -1296,21 +1302,22 @@ def definitely_false(a: BoolLikeType) -> bool:
1296
1302
return not bool (a )
1297
1303
1298
1304
1299
- def _static_eval (x : BoolLikeType ) -> Optional [bool ]:
1300
- if isinstance (x , SymBool ):
1301
- expr = x .node .expr
1305
+ def _static_eval_sym_bool (x : SymBool ) -> Optional [bool ]:
1306
+ assert isinstance (x , SymBool )
1307
+ expr = x .node .expr
1308
+
1309
+ try :
1310
+ # Shape env access is inside the try on purpose. xla symnode does not
1311
+ # have it on its attributes.
1302
1312
shape_env = x .node .shape_env
1303
- try :
1304
- simplified = shape_env ._maybe_evaluate_static (expr )
1305
- if simplified is not None :
1306
- return bool (simplified )
1307
- else :
1308
- return None
1309
- except Exception :
1310
- log .debug ("Could not simplify %s" , expr )
1313
+ simplified = shape_env ._maybe_evaluate_static (expr )
1314
+ if simplified is not None :
1315
+ return bool (simplified )
1316
+ else :
1311
1317
return None
1312
- assert isinstance (x , bool )
1313
- return x
1318
+ except Exception :
1319
+ log .debug ("Could not simplify %s" , expr )
1320
+ return None
1314
1321
1315
1322
1316
1323
def statically_known_true (x : BoolLikeType ) -> bool :
@@ -1324,11 +1331,15 @@ def statically_known_true(x: BoolLikeType) -> bool:
1324
1331
Args:
1325
1332
x (bool, SymBool): The expression to try statically evaluating
1326
1333
"""
1327
- result = _static_eval (x )
1334
+ if not isinstance (x , SymBool ):
1335
+ assert isinstance (x , bool )
1336
+ return x
1337
+
1338
+ result = _static_eval_sym_bool (x )
1328
1339
if result is None :
1329
1340
return False
1330
- else :
1331
- return result
1341
+
1342
+ return result
1332
1343
1333
1344
1334
1345
def sym_and (x : BoolLikeType , * others : BoolLikeType ) -> BoolLikeType :
@@ -3301,6 +3312,10 @@ def __init__(
3301
3312
else []
3302
3313
)
3303
3314
3315
+ # Set true when data dependent errors are handled by caller side and not thrown. Ex: guard_or_false
3316
+ # and guard_or_true. When its true, a different error message is produced.
3317
+ self ._dde_suppressed = False
3318
+
3304
3319
# FakeTensor per-ShapeEnv operation cache. This is used for caching
3305
3320
# operations that contain symbolic shapes which have guards on the
3306
3321
# ShapeEnv (so are ShapeEnv-dependent).
@@ -3313,6 +3328,18 @@ def __init__(
3313
3328
torch ._subclasses .fake_tensor ._DispatchCacheEntry ,
3314
3329
] = {}
3315
3330
3331
+ @contextlib .contextmanager
3332
+ def dde_suppressed (self ) -> Iterator [None ]:
3333
+ """Suppressed GuardOnDataDependent error logs"""
3334
+
3335
+ # We do not expect this to be called recursively.
3336
+ assert not self ._dde_suppressed , "not expected value for _dde_suppressed"
3337
+ self ._dde_suppressed = True
3338
+ try :
3339
+ yield
3340
+ finally :
3341
+ self ._dde_suppressed = False
3342
+
3316
3343
# Pro-tip: if you add new field to ShapeEnv, this affects some accept
3317
3344
# tests. Accept their output with:
3318
3345
#
@@ -3612,6 +3639,7 @@ def check_equal(self, other: ShapeEnv) -> None:
3612
3639
"replacements_slocs" ,
3613
3640
"_resimplify_floor_div_axioms" ,
3614
3641
"_expr_sym_node_id" ,
3642
+ "_dde_suppressed" ,
3615
3643
)
3616
3644
3617
3645
# Mapping of the value of each to-be-compared field into the values that
@@ -5966,7 +5994,6 @@ def simplify(self, expr: _SympyT, size_oblivious: bool = False) -> _SympyT:
5966
5994
min_max_replacements [atom ] = a if atom .func is Max else b
5967
5995
if min_max_replacements :
5968
5996
expr = expr .xreplace (min_max_replacements )
5969
- expr = safe_expand (expr )
5970
5997
5971
5998
# TODO it would seem that this pass is not necessary given the
5972
5999
# below replacement of // with /, but for nested FloorDivs
@@ -6102,6 +6129,12 @@ def _make_data_dependent_error(
6102
6129
size_oblivious_result : Optional [sympy .Basic ] = None ,
6103
6130
expr_sym_node_id : Optional [int ] = None ,
6104
6131
) -> GuardOnDataDependentSymNode :
6132
+ if self ._dde_suppressed :
6133
+ return GuardOnDataDependentSymNode (
6134
+ expr ,
6135
+ "This data dependent error is suppressed and handled by the caller" ,
6136
+ )
6137
+
6105
6138
# TODO: in a Dynamo context, having user code, and having the
6106
6139
# name of the local, will be much better
6107
6140
size_like_symbols = []
@@ -6814,14 +6847,17 @@ def evaluate_expr(
6814
6847
size_oblivious ,
6815
6848
forcing_spec = forcing_spec ,
6816
6849
)
6817
- except Exception :
6818
- self .log .warning (
6819
- "failed during evaluate_expr(%s, hint=%s, size_oblivious=%s, forcing_spec=%s" ,
6820
- orig_expr ,
6821
<
99D1
/td>- hint ,
6822
- size_oblivious ,
6823
- forcing_spec ,
6824
- )
6850
+ except Exception as e :
6851
+ if isinstance (e , GuardOnDataDependentSymNode ) and self ._dde_suppressed :
6852
+ pass
6853
+ else :
6854
+ self .log .warning (
6855
+ "failed during evaluate_expr(%s, hint=%s, size_oblivious=%s, forcing_spec=%s" ,
6856
+ orig_expr ,
6857
+ hint ,
6858
+ size_oblivious ,
6859
+ forcing_spec ,
6860
+ )
6825
6861
raise
6826
6862
6827
6863
def _evaluate_expr (
0 commit comments