misc. fixes to constraints warnings and errors (#100745) · pytorch/pytorch@ca9f55f · GitHub

Commit ca9f55f

avikchaudhuri authored and pytorchmergebot committed
misc. fixes to constraints warnings and errors (#100745)
1. Move constraint violation error after constraint discovery warning, and attach them when we have both.
2. Remove verbose internal traceback for relevant guard in constraint violation error.
3. Remove mention of `assume_static_by_default` in specialization warning.
4. Fix indenting of `specializations` body and make it assert individually instead of returning a conjunction.
5. Remove return annotation on signature used in generated `specializations` and `specify_constraints` functions.
6. Split `&` ranges because we don't support them yet.

Differential Revision: [D45619852](https://our.internmc.facebook.com/intern/diff/D45619852/)

Pull Request resolved: #100745
Approved by: https://github.com/tugsbayasgalan
1 parent 0bf9722 commit ca9f55f
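Item 1 of the message reorders things so that the dimension-constraint summary produced during tracing is attached to a `ConstraintViolationError` when one was raised, and only logged as a warning otherwise. A minimal sketch of that flow (not the real `export()` code): `run_traced` and `summarize` are hypothetical stand-ins for running the compiled function and for `DimConstraints.prettify_results`; only `ConstraintViolationError` is the real type from the diff below.

```python
import logging

from torch.fx.experimental.symbolic_shapes import ConstraintViolationError

log = logging.getLogger(__name__)


def trace_with_summary(run_traced, summarize):
    constraint_violation_error = None
    try:
        run_traced()  # tracing may raise ConstraintViolationError
    except ConstraintViolationError as e:
        constraint_violation_error = e

    msg = summarize()  # discovered dimension constraints, pretty-printed
    if constraint_violation_error:
        # Attach the summary to the error so both appear together ...
        constraint_violation_error.args = (constraint_violation_error.args[0] + msg,)
        raise constraint_violation_error
    # ... otherwise surface the summary as a plain warning.
    log.warning("Summary of dimension constraints:%s", msg)
```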

File tree

3 files changed: +49 −28 lines

test/test_dynamic_shapes.py

Lines changed: 10 additions & 10 deletions
```diff
@@ -1780,16 +1780,16 @@ def dummy_f(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x11, x12):
         print(static_code)
         expected_static = '''
 def specializations(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x11, x12):
-    return (x0.size()[0] == 8 and
-    x1.size()[2] == 96 and
-    x11.size()[1] == 1 and
-    x12.size()[2] == 3 and
-    x2.size()[0] == 8 and
-    x3.size()[0] == 8 and
-    x4.size()[0] == 8 and
-    x5.size()[1] == 22 and
-    x7.size()[3] == 96 and
-    x8.size()[1] == 22)
+    assert x0.size()[0] == 8
+    assert x1.size()[2] == 96
+    assert x11.size()[1] == 1
+    assert x12.size()[2] == 3
+    assert x2.size()[0] == 8
+    assert x3.size()[0] == 8
+    assert x4.size()[0] == 8
+    assert x5.size()[1] == 22
+    assert x7.size()[3] == 96
+    assert x8.size()[1] == 22
 '''
         expected_dynamic = '''
 def specify_constraints(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x11, x12):
```
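Items 3 and 4 of the message change the shape of the generated helper checked by this test; abridged here to two inputs taken from the expected output above:

```python
# Before: one long returned conjunction, so a mismatch only tells you "False".
def specializations_old(x0, x1):
    return (x0.size()[0] == 8 and
        x1.size()[2] == 96)

# After: one assert per specialized dimension, so a failing shape is pinpointed
# by the exact assert that trips.
def specializations_new(x0, x1):
    assert x0.size()[0] == 8
    assert x1.size()[2] == 96
```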

torch/_dynamo/eval_frame.py

Lines changed: 27 additions & 11 deletions
```diff
@@ -27,6 +27,7 @@
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
 from torch.nn.parallel.distributed import DistributedDataParallel
+
 from ..fx import GraphModule
 from .backends.registry import CompilerFn, lookup_backend
 
@@ -64,7 +65,10 @@
 
 import sympy
 
-from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
+from torch.fx.experimental.symbolic_shapes import (
+    ConstraintViolationError,
+    StrictMinMaxConstraint,
+)
 from torch.utils._sympy.value_ranges import ValueRanges
 
 
@@ -861,6 +865,7 @@ def result_capturing_wrapper(*graph_inputs):
     flat_args, in_spec = pytree.tree_flatten((args, kwargs))
 
     remove_from_cache(f)
+    constraint_violation_error = None
     with patch(f"{__name__}.most_recent_backend", None), config.patch(
         summarize_dim_constraints=True,
         specialize_int=True,
@@ -877,24 +882,35 @@ def result_capturing_wrapper(*graph_inputs):
             dynamic=(tracing_mode == "symbolic"),
         )(f)
         # TODO(voz): We may have instances of `f` that mutate inputs, we should track sideffects and reject.
-        result_traced = opt_f(*args, **kwargs)
+        try:
+            result_traced = opt_f(*args, **kwargs)
+        except ConstraintViolationError as e:
+            constraint_violation_error = e
     remove_from_cache(f)
 
+    if (shape_env := getattr(fake_mode, "shape_env", None)) is not None:
+        dim_constraints = shape_env.dim_constraints
+        assert dim_constraints is not None
+        dim_constraints.solve()
+        msg = dim_constraints.prettify_results(inspect.signature(f))
+        if constraint_violation_error:
+            constraint_violation_error.args = (
+                constraint_violation_error.args[0] + msg,
+            )
+        else:
+            log.warning(
+                "Summary of dimension constraints:%s",
+                msg,
+            )
+    if constraint_violation_error:
+        raise constraint_violation_error
+
     assert (
         graph is not None
     ), "Failed to produce a graph during tracing. Tracing through 'f' must produce a single graph."
     assert out_guards is not None, "Failed to produce guards during tracing"
     assert fake_mode is not None
 
-    if (shape_env := getattr(fake_mode, "shape_env", None)) is not None:
-        dim_constraints = shape_env.dim_constraints
-        assert dim_constraints is not None
-        dim_constraints.solve()
-        log.warning(
-            "Summary of dimension constraints:%s",
-            dim_constraints.prettify_results(inspect.signature(f)),
-        )
-
     matched_input_elements_positions = produce_matching(flat_args, graph_captured_input)
 
     flat_results_traced, out_spec_traced = pytree.tree_flatten(result_traced)
```
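From a caller's perspective, this change means `torch._dynamo.export` now re-raises the violation error with the constraint summary appended instead of logging the summary separately. A rough usage sketch, assuming the `export`/`dynamic_dim` API as it existed around this commit; the toy function, shapes, and variable names are made up for illustration:

```python
import torch
import torch._dynamo as dynamo
from torch._dynamo import dynamic_dim
from torch.fx.experimental.symbolic_shapes import ConstraintViolationError


def f(x):
    # Branching on a concrete size forces a specialization of dim 0, which
    # conflicts with declaring that dimension dynamic below.
    return x + 1 if x.shape[0] == 4 else x - 1


x = torch.randn(4, 3)
try:
    gm, guards = dynamo.export(f, x, constraints=[dynamic_dim(x, 0)])
except ConstraintViolationError as e:
    # After this commit, the message also carries the dimension-constraint
    # summary that used to be a separate log.warning.
    print(e)
```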

torch/fx/experimental/symbolic_shapes.py

Lines changed: 12 additions & 7 deletions
```diff
@@ -1692,7 +1692,11 @@ def solve(self):
             try:
                 solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s)
                 # because this is univariate, the solution is a dynamic (range) constraint
-                self._dynamic_results.add(self._dcp.doprint(solution))
+                if isinstance(solution, sympy.And):
+                    for arg in solution.args:
+                        self._dynamic_results.add(self._dcp.doprint(arg))
+                else:
+                    self._dynamic_results.add(self._dcp.doprint(solution))
             except NotImplementedError as e:
                 log.warning("Failed to reduce inequalities: %s", e)
                 for expr in exprs:
@@ -1719,20 +1723,22 @@ def prettify_results(self, original_signature: inspect.Signature):
         def unwrap_local_source(source_name):
             return re.sub(r"L\['(.+?)'\]", r'\1', source_name)
 
+        signature = original_signature.replace(return_annotation=inspect.Signature.empty)
+
         buf = ""
         indent = 4 * " "
         if self._static_results:
             sorted_static_results = [unwrap_local_source(res) for res in sorted(self._static_results)]
             buf += "\nThe following dimensions have been specialized and CANNOT be dynamic."
-            buf += "\nNOTE: Specializations will happen by default with `assume_static_by_default=True`."
-            buf += f"\n```\ndef specializations{str(original_signature)}:"
-            buf += f"\n{indent}return (" + f" and\n{indent}".join(sorted_static_results) + ")"
+            buf += f"\n```\ndef specializations{str(signature)}:"
+            for result in sorted_static_results:
+                buf += f"\n{indent}assert {result}"
             buf += "\n```\n"
         if self._dynamic_results:
             sorted_dynamic_results = sorted(self._dynamic_results)
             buf += "\nThe following dimensions CAN be dynamic."
             buf += "\nYou can use the following code to specify the constraints they must satisfy:"
-            buf += f"\n```\ndef specify_constraints{str(original_signature)}:"
+            buf += f"\n```\ndef specify_constraints{str(signature)}:"
             buf += f"\n{indent}return ["
             for result in sorted_dynamic_results:
                 buf += f"\n{indent*2}{unwrap_local_source(result)},"
@@ -2321,8 +2327,7 @@ def hint(s):
                 if isinstance(c, StrictMinMaxConstraint):
                     msg = (
                         f"Could not validate (strict) constraint {c.render(source)} as "
-                        f"we generated a guard on this size variable: {guard_expr}. Guard "
-                        f"was allocated at:\n{tb}"
+                        f"we generated a guard on this size variable: {guard_expr}."
                     )
                     record_constraint_violation(c.warn_only, msg)
                 elif isinstance(c, RelaxedUnspecConstraint):
```