8000 Revert "Do not generate long log messaged for suppressed data depende… · pytorch/pytorch@bc6c0bc · GitHub
[go: up one dir, main page]

Skip to content

Commit bc6c0bc

Browse files
Revert "Do not generate long log messaged for suppressed data dependent errors. (#151023)"
This reverts commit dfdf731. Reverted #151023 on behalf of https://github.com/laithsakka due to breaking other PRs ([comment](#151023 (comment)))
1 parent 459c62e commit bc6c0bc

File tree

1 file changed

+42
-80
lines changed

1 file changed

+42
-80
lines changed

torch/fx/experimental/symbolic_shapes.py

Lines changed: 42 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import abc
1616
import atexit
1717
import collections
18-
import contextlib
1918
import dis
2019
import functools
2120
import hashlib
@@ -1199,39 +1198,30 @@ def guard_or_false(a: BoolLikeType) -> bool:
11991198
"""
12001199
Try to guard a, if data dependent error encountered just return false.
12011200
"""
1202-
if not isinstance(a, SymBool):
1203-
assert isinstance(a, bool)
1204-
return a
1205-
1206-
with a.node.shape_env.dde_suppressed():
1207-
if torch.fx.experimental._config.backed_size_oblivious:
1208-
return statically_known_true(a)
1209-
else:
1210-
try:
1211-
return guard_bool(a)
1212-
except GuardOnDataDependentSymNode:
1213-
return False
1201+
if torch.fx.experimental._config.backed_size_oblivious:
1202+
return statically_known_true(a)
1203+
else:
1204+
try:
1205+
return bool(guard_bool(a))
1206+
except GuardOnDataDependentSymNode:
1207+
return False
12141208

12151209

12161210
def guard_or_true(a: BoolLikeType) -> bool:
12171211
"""
12181212
Try to guard a, if data dependent error encountered just return true.
12191213
"""
1220-
if not isinstance(a, SymBool):
1221-
assert isinstance(a, bool)
1222-
return a
1223-
1224-
with a.node.shape_env.dde_suppressed():
1225-
if torch.fx.experimental._config.backed_size_oblivious:
1226-
result = _static_eval_sym_bool(a)
1227-
if result is not None:
1228-
return result
1229-
return True
1214+
if torch.fx.experimental._config.backed_size_oblivious:
1215+
result = _static_eval(a)
1216+
if result is not None:
1217+
return result
12301218
else:
1231-
try:
1232-
return guard_bool(a)
1233-
except GuardOnDataDependentSymNode:
1234-
return True
1219+
return True
1220+
else:
1221+
try:
1222+
return bool(guard_bool(a))
1223+
except GuardOnDataDependentSymNode:
1224+
return True
12351225

12361226

12371227
def definitely_true(a: BoolLikeType) -> bool:
@@ -1276,19 +1266,21 @@ def definitely_false(a: BoolLikeType) -> bool:
12761266
return not bool(a)
12771267

12781268

1279-
def _static_eval_sym_bool(x: SymBool) -> Optional[bool]:
1280-
assert isinstance(x, SymBool)
1281-
expr = x.node.expr
1282-
shape_env = x.node.shape_env
1283-
try:
1284-
simplified = shape_env._maybe_evaluate_static(expr)
1285-
if simplified is not None:
1286-
return bool(simplified)
1287-
else:
1269+
def _static_eval(x: Union[bool, SymBool]) -> Optional[bool]:
1270+
if isinstance(x, SymBool):
1271+
expr = x.node.expr
1272+
shape_env = x.node.shape_env
1273+
try:
1274+
simplified = shape_env._maybe_evaluate_static(expr)
1275+
if simplified is not None:
1276+
return bool(simplified)
1277+
else:
1278+
return None
1279+
except Exception:
1280+
log.debug("Could not simplify %s", expr)
12881281
return None
1289-
except Exception:
1290-
log.debug("Could not simplify %s", expr)
1291-
return None
1282+
assert isinstance(x, bool)
1283+
return x
12921284

12931285

12941286
def statically_known_true(x: Union[bool, SymBool]) -> bool:
@@ -1302,15 +1294,11 @@ def statically_known_true(x: Union[bool, SymBool]) -> bool:
13021294
Args:
13031295
x (bool, SymBool): The expression to try statically evaluating
13041296
"""
1305-
if not isinstance(x, SymBool):
1306-
assert isinstance(x, bool)
1307-
return x
1308-
1309-
result = _static_eval_sym_bool(x)
1297+
result = _static_eval(x)
13101298
if result is None:
13111299
return False
1312-
1313-
return result
1300+
else:
1301+
return result
13141302

13151303

13161304
def sym_and(
@@ -3287,10 +3275,6 @@ def __init__(
32873275
else []
32883276
)
32893277

3290-
# Set true when data dependent errors are handled by caller side and not thrown. Ex: guard_or_false
3291-
# and guard_or_true. When its true, a different error message is produced.
3292-
self._dde_suppressed = False
3293-
32943278
# FakeTensor per-ShapeEnv operation cache. This is used for caching
32953279
# operations that contain symbolic shapes which have guards on the
32963280
# ShapeEnv (so are ShapeEnv-dependent).
@@ -3303,18 +3287,6 @@ def __init__(
33033287
torch._subclasses.fake_tensor._DispatchCacheEntry,
33043288
] = {}
33053289

3306-
@contextlib.contextmanager
3307-
def dde_suppressed(self) -> Iterator[None]:
3308-
"""Suppressed GuardOnDataDependent error logs"""
3309-
3310-
# We do not expect this to be called recursively.
3311-
assert not self._dde_suppressed, "not expected value for _dde_suppressed"
3312-
self._dde_suppressed = True
3313-
try:
3314-
yield
3315-
finally:
3316-
self._dde_suppressed = False
3317-
33183290
# Pro-tip: if you add new field to ShapeEnv, this affects some accept
33193291
# tests. Accept their output with:
33203292
#
@@ -3614,7 +3586,6 @@ def check_equal(self, other: ShapeEnv) -> None:
36143586
"replacements_slocs",
36153587
"_resimplify_floor_div_axioms",
36163588
"_expr_sym_node_id",
3617-
"_dde_suppressed",
36183589
)
36193590

36203591
# Mapping of the value of each to-be-compared field into the values that
@@ -6115,12 +6086,6 @@ def _make_data_dependent_error(
61156086
size_oblivious_result: Optional[sympy.Basic] = None,
61166087
expr_sym_node_id: Optional[int] = None,
61176088
) -> GuardOnDataDependentSymNode:
6118-
if self._dde_suppressed:
6119-
return GuardOnDataDependentSymNode(
6120-
expr,
6121-
"This data dependent error is suppressed and handled by the caller",
6122-
)
6123-
61246089
# TODO: in a Dynamo context, having user code, and having the
61256090
# name of the local, will be much better
61266091
size_like_symbols = []
@@ -6833,17 +6798,14 @@ def evaluate_expr(
68336798
size_oblivious,
68346799
forcing_spec=forcing_spec,
68356800
)
6836-
except Exception as e:
6837-
if isinstance(e, GuardOnDataDependentSymNode) and self._dde_suppressed:
6838-
pass
6839-
else:
6840-
self.log.warning(
6841-
"failed during evaluate_expr(%s, hint=%s, size_oblivious=%s, forcing_spec=%s",
6842-
orig_expr,
6843-
hint,
6844-
size_oblivious,
6845-
forcing_spec,
6846-
)
6801+
except Exception:
6802+
self.log.warning(
6803+
"failed during evaluate_expr(%s, hint=%s, size_oblivious=%s, forcing_spec=%s",
6804+
orig_expr,
6805+
hint,
6806+
size_oblivious,
6807+
forcing_spec,
6808+
)
68476809
raise
68486810

68496811
def _evaluate_expr(

0 commit comments

Comments
 (0)
0