Do not generate long log messages for suppressed data dependent error… · pytorch/pytorch@dfdf731 · GitHub
[go: up one dir, main page]

Skip to content

Commit dfdf731

Browse files
laithsakka authored and pytorchmergebot committed
Do not generate long log messages for suppressed data dependent errors. (#151023)
Repro: TORCH_LOGS="all" python test/test_dynamic_shapes.py -k test_guard_or_true. Before: <img width="1065" alt="Screenshot 2025-04-10 at 9 55 27 AM" src="https://github.com/user-attachments/assets/3ee20de0-2902-4eb1-8ab0-80f1b974fb78" /> After: <img width="1124" alt="Screenshot 2025-04-10 at 9 54 35 AM" src="https://github.com/user-attachments/assets/4e7e1f0c-856c-417f-8763-bfe183e2450d" /> Note: we do not actually expect to see a log at all; this is an orthogonal issue in recording, where each error seen is logged even when recording is not enabled. I will follow up with a PR for that. Pull Request resolved: #151023 Approved by: https://github.com/bobrenjc93
1 parent a09a3f4 commit dfdf731

File tree

1 file changed

+80
-42
lines changed

1 file changed

+80
-42
lines changed

torch/fx/experimental/symbolic_shapes.py

Lines changed: 80 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
import abc
1616
import atexit
1717
import collections
18+
import contextlib
1819
import dis
1920
import functools
2021
import hashlib
@@ -1198,30 +1199,39 @@ def guard_or_false(a: BoolLikeType) -> bool:
11981199
"""
11991200
Try to guard a, if data dependent error encountered just return false.
12001201
"""
1201-
if torch.fx.experimental._config.backed_size_oblivious:
1202-
return statically_known_true(a)
1203-
else:
1204-
try:
1205-
return bool(guard_bool(a))
1206-
except GuardOnDataDependentSymNode:
1207-
return False
1202+
if not isinstance(a, SymBool):
1203+
assert isinstance(a, bool)
1204+
return a
1205+
1206+
with a.node.shape_env.dde_suppressed():
1207+
if torch.fx.experimental._config.backed_size_oblivious:
1208+
return statically_known_true(a)
1209+
else:
1210+
try:
1211+
return guard_bool(a)
1212+
except GuardOnDataDependentSymNode:
1213+
return False
12081214

12091215

12101216
def guard_or_true(a: BoolLikeType) -> bool:
12111217
"""
12121218
Try to guard a, if data dependent error encountered just return true.
12131219
"""
1214-
if torch.fx.experimental._config.backed_size_oblivious:
1215-
result = _static_eval(a)
1216-
if result is not None:
1217-
return result
1218-
else:
1219-
return True
1220-
else:
1221-
try:
1222-
return bool(guard_bool(a))
1223-
except GuardOnDataDependentSymNode:
1220+
if not isinstance(a, SymBool):
1221+
assert isinstance(a, bool)
1222+
return a
1223+
1224+
with a.node.shape_env.dde_suppressed():
1225+
if torch.fx.experimental._config.backed_size_oblivious:
1226+
result = _static_eval_sym_bool(a)
1227+
if result is not None:
1228+
return result
12241229
return True
1230+
else:
1231+
try:
1232+
return guard_bool(a)
1233+
except GuardOnDataDependentSymNode:
1234+
return True
12251235

12261236

12271237
def definitely_true(a: BoolLikeType) -> bool:
@@ -1266,21 +1276,19 @@ def definitely_false(a: BoolLikeType) -> bool:
12661276
return not bool(a)
12671277

12681278

1269-
def _static_eval(x: Union[bool, SymBool]) -> Optional[bool]:
1270-
if isinstance(x, SymBool):
1271-
expr = x.node.expr
1272-
shape_env = x.node.shape_env
1273-
try:
1274-
simplified = shape_env._maybe_evaluate_static(expr)
1275-
if simplified is not None:
1276-
return bool(simplified)
1277-
else:
1278-
return None
1279-
except Exception:
1280-
log.debug("Could not simplify %s", expr)
1279+
def _static_eval_sym_bool(x: SymBool) -> Optional[bool]:
1280+
assert isinstance(x, SymBool)
1281+
expr = x.node.expr
1282+
shape_env = x.node.shape_env
1283+
try:
1284+
simplified = shape_env._maybe_evaluate_static(expr)
1285+
if simplified is not None:
1286+
return bool(simplified)
1287+
else:
12811288
return None
1282-
assert isinstance(x, bool)
1283-
return x
1289+
except Exception:
1290+
log.debug("Could not simplify %s", expr)
1291+
return None
12841292

12851293

12861294
def statically_known_true(x: Union[bool, SymBool]) -> bool:
@@ -1294,11 +1302,15 @@ def statically_known_true(x: Union[bool, SymBool]) -> bool:
12941302
Args:
12951303
x (bool, SymBool): The expression to try statically evaluating
12961304
"""
1297-
result = _static_eval(x)
1305+
if not isinstance(x, SymBool):
1306+
assert isinstance(x, bool)
1307+
return x
1308+
1309+
result = _static_eval_sym_bool(x)
12981310
if result is None:
12991311
return False
1300-
else:
1301-
return result
1312+
1313+
return result
13021314

13031315

13041316
def sym_and(
@@ -3275,6 +3287,10 @@ def __init__(
32753287
else []
32763288
)
32773289

3290+
# Set to True when data dependent errors are handled by the caller rather than thrown, e.g. by guard_or_false
3291+
# and guard_or_true. When it is True, a different error message is produced.
3292+
self._dde_suppressed = False
3293+
32783294
# FakeTensor per-ShapeEnv operation cache. This is used for caching
32793295
# operations that contain symbolic shapes which have guards on the
32803296
# ShapeEnv (so are ShapeEnv-dependent).
@@ -3287,6 +3303,18 @@ def __init__(
32873303
torch._subclasses.fake_tensor._DispatchCacheEntry,
32883304
] = {}
32893305

3306+
@contextlib.contextmanager
3307+
def dde_suppressed(self) -> Iterator[None]:
3308+
"""Suppressed GuardOnDataDependent error logs"""
3309+
3310+
# We do not expect this to be called recursively.
3311+
assert not self._dde_suppressed, "not expected value for _dde_suppressed"
3312+
self._dde_suppressed = True
3313+
try:
3314+
yield
3315+
finally:
3316+
self._dde_suppressed = False
3317+
32903318
# Pro-tip: if you add new field to ShapeEnv, this affects some accept
32913319
# tests. Accept their output with:
32923320
#
@@ -3586,6 +3614,7 @@ def check_equal(self, other: ShapeEnv) -> None:
35863614
"replacements_slocs",
35873615
"_resimplify_floor_div_axioms",
35883616
"_expr_sym_node_id",
3617+
"_dde_suppressed",
35893618
)
35903619

35913620
# Mapping of the value of each to-be-compared field into the values that
@@ -6087,6 +6116,12 @@ def _make_data_dependent_error(
60876116
size_oblivious_result: Optional[sympy.Basic] = None,
60886117
expr_sym_node_id: Optional[int] = None,
60896118
) -> GuardOnDataDependentSymNode:
6119+
if self._dde_suppressed:
6120+
return GuardOnDataDependentSymNode(
6121+
expr,
6122+
"This data dependent error is suppressed and handled by the caller",
6123+
)
6124+
60906125
# TODO: in a Dynamo context, having user code, and having the
60916126
# name of the local, will be much better
60926127
size_like_symbols = []
@@ -6799,14 +6834,17 @@ def evaluate_expr(
67996834
size_oblivious,
68006835
forcing_spec=forcing_spec,
68016836
)
6802-
except Exception:
6803-
self.log.warning(
6804-
"failed during evaluate_expr(%s, hint=%s, size_oblivious=%s, forcing_spec=%s",
6805-
orig_expr,
6806-
hint,
6807-
size_oblivious,
6808-
forcing_spec,
6809-
)
6837+
except Exception as e:
6838+
if isinstance(e, GuardOnDataDependentSymNode) and self._dde_suppressed:
6839+
pass
6840+
else:
6841+
self.log.warning(
6842+
"failed during evaluate_expr(%s, hint=%s, size_oblivious=%s, forcing_spec=%s",
6843+
orig_expr,
6844+
hint,
6845+
size_oblivious,
6846+
forcing_spec,
6847+
)
68106848
raise
68116849

68126850
def _evaluate_expr(

0 commit comments

Comments
 (0)
0