Do not generate long log messages for suppressed data dependent error… · pytorch/pytorch@98bd2bd · GitHub
[go: up one dir, main page]

Skip to content

Commit 98bd2bd

Browse files
laithsakka authored and pytorchmergebot committed
Do not generate long log messages for suppressed data dependent errors. (#151023)
TORCH_LOGS="all" python test/test_dynamic_shapes.py -k test_guard_or_true before: <img width="1065" alt="Screenshot 2025-04-10 at 9 55 27 AM" src="https://github.com/user-attachments/assets/3ee20de0-2902-4eb1-8ab0-80f1b974fb78" /> after: <img width="1124" alt="Screenshot 2025-04-10 at 9 54 35 AM" src="https://github.com/user-attachments/assets/4e7e1f0c-856c-417f-8763-bfe183e2450d" /> Note: we actually do not expect to see a log at all; that is an orthogonal issue in recording, which appears to log each error seen even when recording is not enabled. I will follow up with a PR for that. Pull Request resolved: #151023 Approved by: https://github.com/bobrenjc93
1 parent 70d7638 commit 98bd2bd

File tree

1 file changed

+81
-45
lines changed

1 file changed

+81
-45
lines changed

torch/fx/experimental/symbolic_shapes.py

Lines changed: 81 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
import abc
1818
import atexit
1919
import collections
20+
import contextlib
2021
import dis
2122
import functools
2223
import hashlib
@@ -308,8 +309,6 @@ def uninteresting_files() -> set[str]:
308309
torch.fx.experimental.recording,
309310
torch.fx.experimental.sym_node,
310311
torch.fx.interpreter,
311-
torch.fx.proxy,
312-
torch.fx._symbolic_trace,
313312
torch,
314313
torch._compile,
315314
torch._dynamo.eval_frame,
@@ -1224,34 +1223,41 @@ def compute_unbacked_bindings(
12241223
# (1) It's an optimization/additional check I do not want to fail for not performing it.
12251224
# (2) I am willing to deviate from the normal semantics when I have unbacked for the
12261225
# benefit of not failing.
1226+
def _guard_or(a: BoolLikeType, default: bool) -> bool:
1227+
if not isinstance(a, SymBool):
1228+
assert isinstance(a, bool)
1229+
return a
1230+
1231+
# if backed_size_oblivious is True we treat backed as unbacked here.
1232+
if torch.fx.experimental._config.backed_size_oblivious:
1233+
result = _static_eval_sym_bool(a)
1234+
return result if result is not None else default
1235+
1236+
shape_env = getattr(a.node, "shape_env", None)
1237+
1238+
# xla symnode path.
1239+
if shape_env is None:
1240+
return guard_bool(a)
1241+
1242+
with a.node.shape_env.dde_suppressed():
1243+
try:
1244+
return guard_bool(a)
1245+
except GuardOnDataDependentSymNode:
1246+
return default
1247+
1248+
12271249
def guard_or_false(a: BoolLikeType) -> bool:
12281250
"""
12291251
Try to guard a, if data dependent error encountered just return false.
12301252
"""
1231-
if torch.fx.experimental._config.backed_size_oblivious:
1232-
return statically_known_true(a)
1233-
else:
1234-
try:
1235-
return bool(guard_bool(a))
1236-
except GuardOnDataDependentSymNode:
1237-
return False
1253+
return _guard_or(a, False)
12381254

12391255

12401256
def guard_or_true(a: BoolLikeType) -> bool:
12411257
"""
12421258
Try to guard a, if data dependent error encountered just return true.
12431259
"""
1244-
if torch.fx.experimental._config.backed_size_oblivious:
1245-
result = _static_eval(a)
1246-
if result is not None:
1247-
return result
1248-
else:
1249-
return True
1250-
else:
1251-
try:
1252-
return bool(guard_bool(a))
1253-
except GuardOnDataDependentSymNode:
1254-
return True
1260+
return _guard_or(a, True)
12551261

12561262

12571263
def definitely_true(a: BoolLikeType) -> bool:
@@ -1296,21 +1302,22 @@ def definitely_false(a: BoolLikeType) -> bool:
12961302
return not bool(a)
12971303

12981304

1299-
def _static_eval(x: BoolLikeType) -> Optional[bool]:
1300-
if isinstance(x, SymBool):
1301-
expr = x.node.expr
1305+
def _static_eval_sym_bool(x: SymBool) -> Optional[bool]:
1306+
assert isinstance(x, SymBool)
1307+
expr = x.node.expr
1308+
1309+
try:
1310+
# Shape env access is inside the try on purpose. xla symnode does not
1311+
# have it on its attributes.
13021312
shape_env = x.node.shape_env
1303-
try:
1304-
simplified = shape_env._maybe_evaluate_static(expr)
1305-
if simplified is not None:
1306-
return bool(simplified)
1307-
else:
1308-
return None
1309-
except Exception:
1310-
log.debug("Could not simplify %s", expr)
1313+
simplified = shape_env._maybe_evaluate_static(expr)
1314+
if simplified is not None:
1315+
return bool(simplified)
1316+
else:
13111317
return None
1312-
assert isinstance(x, bool)
1313-
return x
1318+
except Exception:
1319+
log.debug("Could not simplify %s", expr)
1320+
return None
13141321

13151322

13161323
def statically_known_true(x: BoolLikeType) -> bool:
@@ -1324,11 +1331,15 @@ def statically_known_true(x: BoolLikeType) -> bool:
13241331
Args:
13251332
x (bool, SymBool): The expression to try statically evaluating
13261333
"""
1327-
result = _static_eval(x)
1334+
if not isinstance(x, SymBool):
1335+
assert isinstance(x, bool)
1336+
return x
1337+
1338+
result = _static_eval_sym_bool(x)
13281339
if result is None:
13291340
return False
1330-
else:
1331-
return result
1341+
1342+
return result
13321343

13331344

13341345
def sym_and(x: BoolLikeType, *others: BoolLikeType) -> BoolLikeType:
@@ -3301,6 +3312,10 @@ def __init__(
33013312
else []
33023313
)
33033314

3315+
# Set to True when data dependent errors are handled by the caller and not propagated. Ex: guard_or_false
3316+
# and guard_or_true. When it is True, a different error message is produced.
3317+
self._dde_suppressed = False
3318+
33043319
# FakeTensor per-ShapeEnv operation cache. This is used for caching
33053320
# operations that contain symbolic shapes which have guards on the
33063321
# ShapeEnv (so are ShapeEnv-dependent).
@@ -3313,6 +3328,18 @@ def __init__(
33133328
torch._subclasses.fake_tensor._DispatchCacheEntry,
33143329
] = {}
33153330

3331+
@contextlib.contextmanager
3332+
def dde_suppressed(self) -> Iterator[None]:
3333+
"""Suppressed GuardOnDataDependent error logs"""
3334+
3335+
# We do not expect this to be called recursively.
3336+
assert not self._dde_suppressed, "not expected value for _dde_suppressed"
3337+
self._dde_suppressed = True
3338+
try:
3339+
yield
3340+
finally:
3341+
self._dde_suppressed = False
3342+
33163343
# Pro-tip: if you add new field to ShapeEnv, this affects some accept
33173344
# tests. Accept their output with:
33183345
#
@@ -3612,6 +3639,7 @@ def check_equal(self, other: ShapeEnv) -> None:
36123639
"replacements_slocs",
36133640
"_resimplify_floor_div_axioms",
36143641
"_expr_sym_node_id",
3642+
"_dde_suppressed",
36153643
)
36163644

36173645
# Mapping of the value of each to-be-compared field into the values that
@@ -5966,7 +5994,6 @@ def simplify(self, expr: _SympyT, size_oblivious: bool = False) -> _SympyT:
59665994
min_max_replacements[atom] = a if atom.func is Max else b
59675995
if min_max_replacements:
59685996
expr = expr.xreplace(min_max_replacements)
5969-
expr = safe_expand(expr)
59705997

59715998
# TODO it would seem that this pass is not necessary given the
59725999
# below replacement of // with /, but for nested FloorDivs
@@ -6102,6 +6129,12 @@ def _make_data_dependent_error(
61026129
size_oblivious_result: Optional[sympy.Basic] = None,
61036130
expr_sym_node_id: Optional[int] = None,
61046131
) -> GuardOnDataDependentSymNode:
6132+
if self._dde_suppressed:
6133+
return GuardOnDataDependentSymNode(
6134+
expr,
6135+
"This data dependent error is suppressed and handled by the caller",
6136+
)
6137+
61056138
# TODO: in a Dynamo context, having user code, and having the
61066139
# name of the local, will be much better
61076140
size_like_symbols = []
@@ -6814,14 +6847,17 @@ def evaluate_expr(
68146847
size_oblivious,
68156848
forcing_spec=forcing_spec,
68166849
)
6817-
except Exception:
6818-
self.log.warning(
6819-
"failed during evaluate_expr(%s, hint=%s, size_oblivious=%s, forcing_spec=%s",
6820-
orig_expr,
6821-
hint,
6822-
size_oblivious,
6823-
forcing_spec,
6824-
)
6850+
except Exception as e:
6851+
if isinstance(e, GuardOnDataDependentSymNode) and self._dde_suppressed:
6852+
pass
6853+
else:
6854+
self.log.warning(
6855+
"failed during evaluate_expr(%s, hint=%s, size_oblivious=%s, forcing_spec=%s",
6856+
orig_expr,
6857+
hint,
6858+
size_oblivious,
6859+
forcing_spec,
6860+
)
68256861
raise
68266862

68276863
def _evaluate_expr(

0 commit comments

Comments
 (0)
0