8000 include user stacks with constraint violation error message (#152924) · pytorch/pytorch@70c8047 · GitHub
[go: up one dir, main page]

Skip to content

Commit 70c8047

Browse files
bobrenjc93pytorchmergebot
authored andcommitted
include user stacks with constraint violation error message (#152924)
Fixes #152918 Before: ``` File "/data/users/bobren/a/pytorch/torch/fx/experimental/symbolic_shapes.py", line 5588, in produce_guards_verbose raise ConstraintViolationError( torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['x'].size()[0])! For more information, run with TORCH_LOGS="+dynamic". - You marked L['x'].size()[0] as dynamic but your code specialized it to be a constant (5). Either remove the mark_dynamic or use a less strict API such as maybe_mark_dynamic or Dim.AUTO. ``` After: ``` File "/data/users/bobren/a/pytorch/torch/fx/experimental/symbolic_shapes.py", line 5588, in produce_guards_verbose raise ConstraintViolationError( torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['x'].size()[0])! For more information, run with TORCH_LOGS="+dynamic". - You marked L['x'].size()[0] as dynamic but your code specialized it to be a constant (5). Either remove the mark_dynamic or use a less strict API such as maybe_mark_dynamic or Dim.AUTO. User stack: File "/home/bobren/local/a/pytorch/error.py", line 5, in foo return torch.randn(5) * x ``` Pull Request resolved: #152924 Approved by: https://github.com/pianpwk
1 parent 4c11b26 commit 70c8047

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

torch/fx/experimental/symbolic_shapes.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3542,6 +3542,8 @@ def _init(
35423542
# with something like effect token tracking.
35433543
self.unbacked_alloc_order: dict[sympy.Symbol, int] = {}
35443544

3545+
self.specialization_stacks: dict[Source, traceback.StackSummary] = {}
3546+
35453547
self.trace_asserts = trace_asserts
35463548

35473549
from torch.fx.experimental.validator import translation_validation_enabled
@@ -3624,6 +3626,7 @@ def check_equal(self, other: ShapeEnv) -> None:
36243626
"_resimplify_floor_div_axioms",
36253627
"_expr_sym_node_id",
36263628
"_dde_suppressed",
3629+
"specialization_stacks",
36273630
)
36283631

36293632
# Mapping of the value of each to-be-compared field into the values that
@@ -5159,10 +5162,16 @@ def hint(s: sympy.Expr) -> str:
51595162
var_with_range = self._render_range_for_constraint_violation(
51605163
source, constraint
51615164
)
5165+
user_stack = self.specialization_stacks.get(source, None)
51625166
msg = (
51635167
f"You marked {self._debug_name(source)} as dynamic but your code "
51645168
f"specialized it to be a constant ({val}). Either remove the mark_dynamic "
51655169
f"or use a less strict API such as maybe_mark_dynamic or Dim.AUTO."
5170+
+ (
5171+
"\n\nUser stack:\n" + "".join(user_stack.format())
5172+
if user_stack
5173+
else ""
5174+
)
51665175
)
51675176
record_constraint_violation(
51685177
constraint.warn_only, self._debug_name(source), msg
@@ -6371,6 +6380,10 @@ def _set_replacement(self, a: sympy.Symbol, tgt: sympy.Expr, msg: str) -> None:
63716380
},
63726381
)
63736382

6383+
for source in self.var_to_sources.get(a, []):
6384+
if user_tb:
6385+
self.specialization_stacks[source] = user_tb
6386+
63746387
if config.print_specializations:
63756388
self.log.warning(
63766389
"Specializing %s to %s", self.var_to_sources[a][0].name(), tgt

0 commit comments

Comments
 (0)
0