15
15
import abc
16
16
import atexit
17
17
import collections
18
+ import contextlib
18
19
import dis
19
20
import functools
20
21
import hashlib
@@ -1198,30 +1199,39 @@ def guard_or_false(a: BoolLikeType) -> bool:
1198
1199
"""
1199
1200
Try to guard a, if data dependent error encountered just return false.
1200
1201
"""
1201
- if torch .fx .experimental ._config .backed_size_oblivious :
1202
- return statically_known_true (a )
1203
- else :
1204
- try :
1205
- return bool (guard_bool (a ))
1206
- except GuardOnDataDependentSymNode :
1207
- return False
1202
+ if not isinstance (a , SymBool ):
1203
+ assert isinstance (a , bool )
1204
+ return a
1205
+
1206
+ with a .node .shape_env .dde_suppressed ():
1207
+ if torch .fx .experimental ._config .backed_size_oblivious :
1208
+ return statically_known_true (a )
1209
+ else :
1210
+ try :
1211
+ return guard_bool (a )
1212
+ except GuardOnDataDependentSymNode :
1213
+ return False
1208
1214
1209
1215
1210
1216
def guard_or_true (a : BoolLikeType ) -> bool :
1211
1217
"""
1212
1218
Try to guard a, if data dependent error encountered just return true.
1213
1219
"""
1214
- if torch .fx .experimental ._config .backed_size_oblivious :
1215
- result = _static_eval (a )
1216
- if result is not None :
1217
- return result
1218
- else :
1219
- return True
1220
- else :
1221
- try :
1222
- return bool (guard_bool (a ))
1223
- except GuardOnDataDependentSymNode :
1220
+ if not isinstance (a , SymBool ):
1221
+ assert isinstance (a , bool )
1222
+ return a
1223
+
1224
+ with a .node .shape_env .dde_suppressed ():
1225
+ if torch .fx .experimental ._config .backed_size_oblivious :
1226
+ result = _static_eval_sym_bool (a )
1227
+ if result is not None :
1228
+ return result
1224
1229
return True
1230
+ else :
1231
+ try :
1232
+ return guard_bool (a )
1233
+ except GuardOnDataDependentSymNode :
1234
+ return True
1225
1235
1226
1236
1227
1237
def definitely_true (a : BoolLikeType ) -> bool :
@@ -1266,21 +1276,19 @@ def definitely_false(a: BoolLikeType) -> bool:
1266
1276
return not bool (a )
1267
1277
1268
1278
1269
- def _static_eval (x : Union [bool , SymBool ]) -> Optional [bool ]:
1270
- if isinstance (x , SymBool ):
1271
- expr = x .node .expr
1272
- shape_env = x .node .shape_env
1273
- try :
1274
- simplified = shape_env ._maybe_evaluate_static (expr )
1275
- if simplified is not None :
1276
- return bool (simplified )
1277
- else :
1278
- return None
1279
- except Exception :
1280
- log .debug ("Could not simplify %s" , expr )
1279
+ def _static_eval_sym_bool (x : SymBool ) -> Optional [bool ]:
1280
+ assert isinstance (x , SymBool )
1281
+ expr = x .node .expr
1282
+ shape_env = x .node .shape_env
1283
+ try :
1284
+ simplified = shape_env ._maybe_evaluate_static (expr )
1285
+ if simplified is not None :
1286
+ return bool (simplified )
1287
+ else :
1281
1288
return None
1282
- assert isinstance (x , bool )
1283
- return x
1289
+ except Exception :
1290
+ log .debug ("Could not simplify %s" , expr )
1291
+ return None
1284
1292
1285
1293
1286
1294
def statically_known_true (x : Union [bool , SymBool ]) -> bool :
@@ -1294,11 +1302,15 @@ def statically_known_true(x: Union[bool, SymBool]) -> bool:
1294
1302
Args:
1295
1303
x (bool, SymBool): The expression to try statically evaluating
1296
1304
"""
1297
- result = _static_eval (x )
1305
+ if not isinstance (x , SymBool ):
1306
+ assert isinstance (x , bool )
1307
+ return x
1308
+
1309
+ result = _static_eval_sym_bool (x )
1298
1310
if result is None :
1299
1311
return False
1300
- else :
1301
- return result
1312
+
1313
+ return result
1302
1314
1303
1315
1304
1316
def sym_and (
@@ -3275,6 +3287,10 @@ def __init__(
3275
3287
else []
3276
3288
)
3277
3289
3290
+ # Set true when data dependent errors are handled by caller side and not thrown. Ex: guard_or_false
3291
+ # and guard_or_true. When its true, a different error message is produced.
3292
+ self ._dde_suppressed = False
3293
+
3278
3294
# FakeTensor per-ShapeEnv operation cache. This is used for caching
3279
3295
# operations that contain symbolic shapes which have guards on the
3280
3296
# ShapeEnv (so are ShapeEnv-dependent).
@@ -3287,6 +3303,18 @@ def __init__(
3287
3303
torch ._subclasses .fake_tensor ._DispatchCacheEntry ,
3288
3304
] = {}
3289
3305
3306
+ @contextlib .contextmanager
3307
+ def dde_suppressed (self ) -> Iterator [None ]:
3308
+ """Suppressed GuardOnDataDependent error logs"""
3309
+
3310
+ # We do not expect this to be called recursively.
3311
+ assert not self ._dde_suppressed , "not expected value for _dde_suppressed"
3312
+ self ._dde_suppressed = True
3313
+ try :
3314
+ yield
3315
+ finally :
3316
+ self ._dde_suppressed = False
3317
+
3290
3318
# Pro-tip: if you add new field to ShapeEnv, this affects some accept
3291
3319
# tests. Accept their output with:
3292
3320
#
@@ -3586,6 +3614,7 @@ def check_equal(self, other: ShapeEnv) -> None:
3586
3614
"replacements_slocs" ,
3587
3615
"_resimplify_floor_div_axioms" ,
3588
3616
"_expr_sym_node_id" ,
3617
+ "_dde_suppressed" ,
3589
3618
)
3590
3619
3591
3620
# Mapping of the value of each to-be-compared field into the values that
@@ -6087,6 +6116,12 @@ def _make_data_dependent_error(
6087
6116
size_oblivious_result : Optional [sympy .Basic ] = None ,
6088
6117
expr_sym_node_id : Optional [int ] = None ,
6089
6118
) -> GuardOnDataDependentSymNode :
6119
+ if self ._dde_suppressed :
6120
+ return GuardOnDataDependentSymNode (
6121
+ expr ,
6122
+ "This data dependent error is suppressed and handled by the caller" ,
6123
+ )
6124
+
6090
6125
# TODO: in a Dynamo context, having user code, and having the
6091
6126
# name of the local, will be much better
6092
6127
size_like_symbols = []
@@ -6799,14 +6834,17 @@ def evaluate_expr(
6799
6834
size_oblivious ,
6800
6835
forcing_spec = forcing_spec ,
6801
6836
)
6802
- except Exception :
6803
- self .log .warning (
6804
- "failed during evaluate_expr(%s, hint=%s, size_oblivious=%s, forcing_spec=%s" ,
6805
- orig_expr ,
6806
- hint ,
6807
- size_oblivious ,
6808
- forcing_spec ,
6809
- )
6837
+ except Exception as e :
6838
+ if isinstance (e , GuardOnDataDependentSymNode ) and self ._dde_suppressed :
6839
+ pass
6840
+ else :
6841
+ self .log .warning (
6842
+ "failed during evaluate_expr(%s, hint=%s, size_oblivious=%s, forcing_spec=%s" ,
6843
+ orig_expr ,
6844
+ hint ,
6845
+ size_oblivious ,
6846
+ forcing_spec ,
6847
+ )
6810
6848
raise
6811
6849
6812
6850
def _evaluate_expr (
0 commit comments