15
15
import abc
16
16
import atexit
17
17
import collections
18
- import contextlib
19
18
import dis
20
19
import functools
21
20
import hashlib
@@ -1199,39 +1198,30 @@ def guard_or_false(a: BoolLikeType) -> bool:
1199
1198
"""
1200
1199
Try to guard a, if data dependent error encountered just return false.
1201
1200
"""
1202
- if not isinstance (a , SymBool ):
1203
- assert isinstance (a , bool )
1204
- return a
1205
-
1206
- with a .node .shape_env .dde_suppressed ():
1207
- if torch .fx .experimental ._config .backed_size_oblivious :
1208
- return statically_known_true (a )
1209
- else :
1210
- try :
1211
- return guard_bool (a )
1212
- except GuardOnDataDependentSymNode :
1213
- return False
1201
+ if torch .fx .experimental ._config .backed_size_oblivious :
1202
+ return statically_known_true (a )
1203
+ else :
1204
+ try :
1205
+ return bool (guard_bool (a ))
1206
+ except GuardOnDataDependentSymNode :
1207
+ return False
1214
1208
1215
1209
1216
1210
def guard_or_true (a : BoolLikeType ) -> bool :
1217
1211
"""
1218
1212
Try to guard a, if data dependent error encountered just return true.
1219
1213
"""
1220
- if not isinstance (a , SymBool ):
1221
- assert isinstance (a , bool )
1222
- return a
1223
-
1224
- with a .node .shape_env .dde_suppressed ():
1225
- if torch .fx .experimental ._config .backed_size_oblivious :
1226
- result = _static_eval_sym_bool (a )
1227
- if result is not None :
1228
- return result
1229
- return True
1214
+ if torch .fx .experimental ._config .backed_size_oblivious :
1215
+ result = _static_eval (a )
1216
+ if result is not None :
1217
+ return result
1230
1218
else :
1231
- try :
1232
- return guard_bool (a )
1233
- except GuardOnDataDependentSymNode :
1234
- return True
1219
+ return True
1220
+ else :
1221
+ try :
1222
+ return bool (guard_bool (a ))
1223
+ except GuardOnDataDependentSymNode :
1224
+ return True
1235
1225
1236
1226
1237
1227
def definitely_true (a : BoolLikeType ) -> bool :
@@ -1276,19 +1266,21 @@ def definitely_false(a: BoolLikeType) -> bool:
1276
1266
return not bool (a )
1277
1267
1278
1268
1279
- def _static_eval_sym_bool (x : SymBool ) -> Optional [bool ]:
1280
- assert isinstance (x , SymBool )
1281
- expr = x .node .expr
1282
- shape_env = x .node .shape_env
1283
- try :
1284
- simplified = shape_env ._maybe_evaluate_static (expr )
1285
- if simplified is not None :
1286
- return bool (simplified )
1287
- else :
1269
+ def _static_eval (x : Union [bool , SymBool ]) -> Optional [bool ]:
1270
+ if isinstance (x , SymBool ):
1271
+ expr = x .node .expr
1272
+ shape_env = x .node .shape_env
1273
+ try :
1274
+ simplified = shape_env ._maybe_evaluate_static (expr )
1275
+ if simplified is not None :
1276
+ return bool (simplified )
1277
+ else :
1278
+ return None
1279
+ except Exception :
1280
+ log .debug ("Could not simplify %s" , expr )
1288
1281
return None
1289
- except Exception :
1290
- log .debug ("Could not simplify %s" , expr )
1291
- return None
1282
+ assert isinstance (x , bool )
1283
+ return x
1292
1284
1293
1285
1294
1286
def statically_known_true (x : Union [bool , SymBool ]) -> bool :
@@ -1302,15 +1294,11 @@ def statically_known_true(x: Union[bool, SymBool]) -> bool:
1302
1294
Args:
1303
1295
x (bool, SymBool): The expression to try statically evaluating
1304
1296
"""
1305
- if not isinstance (x , SymBool ):
1306
- assert isinstance (x , bool )
1307
- return x
1308
-
1309
- result = _static_eval_sym_bool (x )
1297
+ result = _static_eval (x )
1310
1298
if result is None :
1311
1299
return False
1312
-
1313
- return result
1300
+ else :
1301
+ return result
1314
1302
1315
1303
1316
1304
def sym_and (
F438
div>
@@ -3287,10 +3275,6 @@ def __init__(
3287
3275
else []
3288
3276
)
3289
3277
3290
- # Set true when data dependent errors are handled by caller side and not thrown. Ex: guard_or_false
3291
- # and guard_or_true. When its true, a different error message is produced.
3292
- self ._dde_suppressed = False
3293
-
3294
3278
# FakeTensor per-ShapeEnv operation cache. This is used for caching
3295
3279
# operations that contain symbolic shapes which have guards on the
3296
3280
# ShapeEnv (so are ShapeEnv-dependent).
@@ -3303,18 +3287,6 @@ def __init__(
3303
3287
torch ._subclasses .fake_tensor ._DispatchCacheEntry ,
3304
3288
] = {}
3305
3289
3306
- @contextlib .contextmanager
3307
- def dde_suppressed (self ) -> Iterator [None ]:
3308
- """Suppressed GuardOnDataDependent error logs"""
3309
-
3310
- # We do not expect this to be called recursively.
3311
- assert not self ._dde_suppressed , "not expected value for _dde_suppressed"
3312
- self ._dde_suppressed = True
3313
- try :
3314
- yield
3315
- finally :
3316
- self ._dde_suppressed = False
3317
-
3318
3290
# Pro-tip: if you add new field to ShapeEnv, this affects some accept
3319
3291
# tests. Accept their output with:
3320
3292
#
@@ -3614,7 +3586,6 @@ def check_equal(self, other: ShapeEnv) -> None:
3614
3586
"replacements_slocs" ,
3615
3587
"_resimplify_floor_div_axioms" ,
3616
3588
"_expr_sym_node_id" ,
3617
- "_dde_suppressed" ,
3618
3589
)
3619
3590
3620
3591
# Mapping of the value of each to-be-compared field into the values that
@@ -6115,12 +6086,6 @@ def _make_data_dependent_error(
6115
6086
size_oblivious_result : Optional [sympy .Basic ] = None ,
6116
6087
expr_sym_node_id : Optional [int ] = None ,
6117
6088
) -> GuardOnDataDependentSymNode :
6118
- if self ._dde_suppressed :
6119
- return GuardOnDataDependentSymNode (
6120
- expr ,
6121
- "This data dependent error is suppressed and handled by the caller" ,
6122
- )
6123
-
6124
6089
# TODO: in a Dynamo context, having user code, and having the
6125
6090
# name of the local, will be much better
6126
6091
size_like_symbols = []
@@ -6833,17 +6798,14 @@ def evaluate_expr(
6833
6798
size_oblivious ,
6834
6799
forcing_spec = forcing_spec ,
6835
6800
)
6836
- except Exception as e :
6837
- if isinstance (e , GuardOnDataDependentSymNode ) and self ._dde_suppressed :
6838
- pass
6839
- else :
6840
- self .log .warning (
6841
- "failed during evaluate_expr(%s, hint=%s, size_oblivious=%s, forcing_spec=%s" ,
6842
- orig_expr ,
6843
- hint ,
6844
- size_oblivious ,
6845
- forcing_spec ,
6846
- )
6801
+ except Exception :
6802
+ self .log .warning (
6803
+ "failed during evaluate_expr(%s, hint=%s, size_oblivious=%s, forcing_spec=%s" ,
6804
+ orig_expr ,
6805
+ hint ,
6806
+ size_oblivious ,
6807
+ forcing_spec ,
6808
+ )
6847
6809
raise
6848
6810
6849
6811
def _evaluate_expr (
0 commit comments