[recompiles] suggest whitelist for dynamic shape recompilations by pianpwk · Pull Request #153442 · pytorch/pytorch · GitHub

[recompiles] suggest whitelist for dynamic shape recompilations #153442

Closed · wants to merge 13 commits
13 changes: 6 additions & 7 deletions test/dynamo/test_guard_manager.py
@@ -77,7 +77,8 @@ def test_global_state_guard(self):
"""\
GuardDebugInfo(
result=0,
verbose_code_parts=['GLOBAL_STATE changed: default_dtype '],
verbose_code_parts=[],
failure_reasons=['GLOBAL_STATE changed: default_dtype '],
num_guards_executed=0)
""",
)
@@ -92,7 +93,8 @@ def test_global_state_guard(self):
"""\
GuardDebugInfo(
result=0,
verbose_code_parts=['GLOBAL_STATE changed: deterministic_algorithms '],
verbose_code_parts=[],
failure_reasons=['GLOBAL_STATE changed: deterministic_algorithms '],
num_guards_executed=0)
""",
)
@@ -307,10 +309,7 @@ def test_tensor_match_guard(self):
x = torch.randn(4, 4)
x.t_()
debug_info = guard_manager.check_verbose(x)
print(debug_info.verbose_code_parts[0])
self.assertTrue(
"tensor 'x' stride mismatch" in debug_info.verbose_code_parts[0]
)
self.assertTrue("tensor 'x' stride mismatch" in debug_info.failure_reasons[0])

def test_no_tensor_aliasing_guard(self):
guard_manager = RootGuardManager()
@@ -694,7 +693,7 @@ def fn(x):
self.assertFalse(guard_manager.check(None))
debug_info = guard_manager.check_verbose(None)
self.assertFalse(debug_info.result)
self.assertTrue("Test" in debug_info.verbose_code_parts[0])
self.assertTrue("Test" in debug_info.failure_reasons[0])

def test_dict_contains_guard(self):
foo = {"a": 1, "b": 2}
72 changes: 72 additions & 0 deletions test/dynamo/test_recompile_ux.py
@@ -1,4 +1,5 @@
# Owner(s): ["module: dynamo"]
import re
import unittest
import weakref

@@ -183,6 +184,77 @@ def cache_fail_test(cached_input, missed_input, expected_failure):
"tensor 'a' requires_grad mismatch. expected requires_grad=0",
)

@torch._dynamo.config.patch(recompile_limit=8)
def test_dynamic_whitelist_suggestion(self):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.lin = torch.nn.Linear(4, 4)
self.attr = torch.randn(4)
self.mode = "a"

def forward(self, x):
if self.mode == "a":
return self.lin(x) + self.attr
else:
return self.lin(x) - self.attr

def run():
torch.compiler.reset()
opt_f = torch.compile(Foo(), backend="eager", dynamic=False)
opt_f(torch.randn(7, 4))
opt_f.mode = "b"
opt_f.lin = torch.nn.Linear(5, 5)
opt_f.attr = torch.randn(5)
opt_f(torch.randn(9, 5))

with log_settings(kwargs_to_settings(recompiles=True)):
with self.assertLogs(
logger=torch._dynamo.guards.recompiles_log, level="DEBUG"
) as logs:
run()

# check recompile reason only points to .mode attribute
re_msg = (
r".*triggered by the following guard failure\(s\)(.*\n)*.*"
r"0\/0: self.mode == \'a\'(.*\n)*.*"
r"Multiple size mismatches found(.*\n)*.*"
)
self.assertRegex(logs.records[0].msg, re_msg)
self.assertTrue("size mismatch at index" not in logs.records[0].msg)

# check whitelist
match = re.search(
r'TORCH_COMPILE_DYNAMIC_SOURCES="(.*)"', logs.records[0].msg
)
self.assertTrue(match is not None)
whitelist = match.group(1)
for src_name in [
"L['x']",
"L['self'].attr",
"L['self']._modules['lin']._parameters['bias']",
"L['self']._modules['lin']._parameters['weight']",
]:
self.assertTrue(src_name in whitelist)

with log_settings(kwargs_to_settings(recompiles_verbose=True)):
with self.assertLogs(
logger=torch._dynamo.guards.recompiles_verbose_log, level="DEBUG"
) as logs:
run()

# check all recompile reasons
re_msg = (
r".*triggered by the following guard failure\(s\)(.*\n)*.*"
r"0\/0: self.mode == \'a\'(.*\n)*.*"
r"tensor 'x' size mismatch at index 0(.*\n)*.*"
r"tensor 'self.attr' size mismatch at index 0(.*\n)*.*"
r"tensor 'self._modules\['lin'\]._parameters\['bias'\]' size mismatch at index 0(.*\n)*.*"
r"tensor 'self._modules\['lin'\]._parameters\['weight'\]' size mismatch at index 0(.*\n)*.*"
r"Multiple size mismatches found(.*\n)*.*"
)
self.assertRegex(logs.records[0].msg, re_msg)

def test_mismatched_type(self):
a = torch.rand(3, 4, 5)
b = torch.rand(3, 4, 5)
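For context, a minimal sketch (not part of this PR's diff) of how a user might act on the suggested whitelist; the module, shapes, and source names here are illustrative:

# Sketch: apply the whitelist that the recompile log suggests. Setting
# TORCH_COMPILE_DYNAMIC_SOURCES before compiling marks the listed sources as
# dynamic on the first compile, so the batch-size change below should no
# longer trigger a recompile. Source names follow Dynamo's source syntax.
import os

os.environ["TORCH_COMPILE_DYNAMIC_SOURCES"] = "L['x']"

import torch


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.lin(x)


opt = torch.compile(M(), backend="eager", dynamic=False)
opt(torch.randn(7, 4))  # first compile; x's batch dim is marked dynamic up front
opt(torch.randn(9, 4))  # different batch size: should hit the same compiled graph
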
109 changes: 93 additions & 16 deletions torch/_dynamo/guards.py
@@ -31,6 +31,7 @@
import logging
import math
import pickle
import re
import sys
import textwrap
import types
@@ -3273,38 +3274,43 @@ def get_guard_fail_reason_helper(
guard_manager: GuardFn,
f_locals: dict[str, object],
compile_id: CompileId,
) -> str:
return_all: bool = False,
) -> Union[str, tuple[str, list[str], list[str]]]:
"""
Return the reason why `guard_manager` failed.
Updates `guard_failures` with the generated reason.
Only the first failed check of guard_manager is reported.
Only the first failed check of guard_manager is reported,
unless TORCH_LOGS="recompiles_verbose" is specified.
"""
scope = {"L": f_locals, "G": guard_manager.global_scope["G"]}
scope.update(guard_manager.closure_vars)
reasons: list[str] = []

no_tensor_aliasing_check_failed = False

failure_reasons: list[str] = []
verbose_code_parts: list[str] = []
guard_debug_info = guard_manager.check_verbose(f_locals) # type: ignore[attr-defined]
# For test_export_with_map_cond, check_verbose fails even without the
# C++ guard manager. We need to fix that issue to remove this comment.
# assert not guard_debug_info.result
if not guard_debug_info.result:
verbose_code_parts = guard_debug_info.verbose_code_parts
# verbose_code_parts is either the actual reason (e.g. in case of
# TENSOR_MATCH) or it could be a list of verbose_code_part that we
failure_reasons = guard_debug_info.failure_reasons
# verbose_code_parts is a list of verbose_code_part that we
# passed to the leaf guard at construction time. If it's a list, we
# walk through this list and find the guard that failed. This is
# very important for symbolic shape guards which are currently
# installed as a lambda guard and can encompass a long list of code_parts.
# failure_reasons is a list of non-eval-able strings in plain English.

if len(verbose_code_parts) == 1:
if "Duplicate tensor found" in verbose_code_parts[0]:
no_tensor_aliasing_check_failed = True
else:
reasons = verbose_code_parts
verbose_code_parts = []
if any("Duplicate tensor found" in reason for reason in failure_reasons):
no_tensor_aliasing_check_failed = True
elif len(verbose_code_parts) == 1:
reasons = verbose_code_parts
verbose_code_parts = []
elif len(verbose_code_parts) == 0:
reasons = []

if no_tensor_aliasing_check_failed:
reasons = recompilation_reason_for_no_tensor_aliasing_guard(
@@ -3332,19 +3338,28 @@ def get_guard_fail_reason_helper(
if not is_recompiles_verbose_enabled():
break

reason_str = f"{compile_id}: " + "; ".join(reasons)
return strip_local_scope(reason_str)
all_reasons = reasons + failure_reasons
if is_recompiles_verbose_enabled():
reason_str = strip_local_scope(f"{compile_id}: " + "; ".join(all_reasons))
else:
reason_str = strip_local_scope(f"{compile_id}: " + "; ".join(all_reasons[:1]))

if return_all:
return reason_str, reasons, failure_reasons
return reason_str


def get_guard_fail_reason(
guard_manager: GuardFn,
code: types.CodeType,
f_locals: dict[str, object],
compile_id: CompileId,
) -> str:
) -> tuple[str, list[str], list[str]]:
if isinstance(guard_manager, DeletedGuardManagerWrapper):
return f"{compile_id}: {guard_manager.invalidation_reason}"
reason_str = get_guard_fail_reason_helper(guard_manager, f_locals, compile_id)
return f"{compile_id}: {guard_manager.invalidation_reason}", [], []
reason_str, code_reasons, fail_reasons = get_guard_fail_reason_helper( # type: ignore[misc]
guard_manager, f_locals, compile_id, return_all=True
)
guard_failures[orig_code_map[code]].append(reason_str)

try:
@@ -3357,6 +3372,55 @@ def get_guard_fail_reason(
"Failure in guard_fail_fn callback - raising here will cause a NULL Error on guard eval",
)

return reason_str, code_reasons, fail_reasons


def _extract_recompiled_dynamic_sources(
fail_reasons: list[str], frame_locals: dict[str, object], guard_manager: GuardFn
) -> tuple[list[str], bool]:
"""
Goes through a list of guard failure reasons and extracts source names for tensors whose size changes triggered the recompile.

Returns a list of sources, and a boolean indicating whether a parameter was included.
"""

[Inline review comment from a Contributor] Hmmmm so one downside of this opt-in is by default users won't get this for their first mast run?

scope = {"L": frame_locals, "G": guard_manager.global_scope["G"]}
scope.update(guard_manager.closure_vars)
has_parameter = False
shape_sources: OrderedSet[str] = OrderedSet()
pattern = r"tensor '(.*)' size mismatch at index .* expected (\d+), actual (\d+).*"
for reason in fail_reasons:
if (match := re.search(pattern, reason)) is not None:
name, size_orig, size_new = match.groups()
tensor = eval(name, scope)
size_orig, size_new = int(size_orig), int(size_new)
if (
size_orig >= 2 and size_new >= 2
): # these suggestions won't help for 0/1 specialization.
if isinstance(tensor, torch.nn.Parameter):
has_parameter = True
shape_sources.add(name)
return list(shape_sources), has_parameter


def _suggest_dynamic_whitelist(
dynamic_sources: OrderedSet[str], has_dynamic_parameter: bool
) -> str:
"""
Builds the suggested-dynamic-whitelist message appended to recompile logging, based on the detected sources.
"""
reason_str = ""
if dynamic_sources:
reason_str += "\n"
if len(dynamic_sources) > 1:
reason_str += "\nMultiple size mismatches found. "
reason_str += (
"The following environment variable would enable dynamic compilation to start, avoiding this recompile: "
+ f'TORCH_COMPILE_DYNAMIC_SOURCES="{",".join(dynamic_sources)}"'
)
if has_dynamic_parameter:
reason_str += (
"\nSize guard failed on a parameter, consider using torch._dynamo.config.force_parameter_static_shapes = False "
+ "to allow dynamism on parameters."
)
return reason_str
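
A self-contained illustration of the extraction regex used above; the sample reason string only approximates the real TENSOR_MATCH failure wording:

import re

# Hypothetical failure reason; real messages may differ slightly in punctuation.
reason = "tensor 'L['x']' size mismatch at index 0. expected 7, actual 9"
pattern = r"tensor '(.*)' size mismatch at index .* expected (\d+), actual (\d+).*"
match = re.search(pattern, reason)
assert match is not None
name, size_orig, size_new = match.groups()
print(name)  # L['x'] -- the source name, later eval'd against the frame scope
print(int(size_orig), int(size_new))  # 7 9 -- both >= 2, so 0/1 specialization is not the cause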


@@ -3369,15 +3433,25 @@ def get_and_maybe_log_recompilation_reasons(
Raises a RecompileError if `config.error_on_recompile` is enabled.
"""
reasons = []
dynamic_sources: OrderedSet[str] = OrderedSet()
has_dynamic_parameter = False
while cache_entry is not None:
reason = get_guard_fail_reason(
reason, _, fail_reasons = get_guard_fail_reason(
cache_entry.guard_manager,
cache_entry.code,
frame.f_locals,
cache_entry.compile_id,
)
if reason:
reasons.append(reason)
if not isinstance(cache_entry.guard_manager, DeletedGuardManagerWrapper):
new_dynamic_sources, dynamic_param = (
_extract_recompiled_dynamic_sources(
fail_reasons, frame.f_locals, cache_entry.guard_manager
)
)
dynamic_sources.update(new_dynamic_sources)
has_dynamic_parameter |= dynamic_param
cache_entry = cache_entry.next

code = frame.f_code
Expand All @@ -3396,6 +3470,9 @@ def get_and_maybe_log_recompilation_reasons(
guard_failure_details = (
f"triggered by the following guard failure(s):\n{failures}"
)
guard_failure_details += _suggest_dynamic_whitelist(
dynamic_sources, has_dynamic_parameter
)
message = (
f"Recompiling function {code.co_name} in {code.co_filename}:{code.co_firstlineno}\n"
f"{textwrap.indent(guard_failure_details, ' ')}"
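Taken together, the resulting recompile log looks roughly like the following (reconstructed from the message-building code above; the file name, line number, sizes, and source list are invented):

Recompiling function forward in my_model.py:10
    triggered by the following guard failure(s):
    - 0/0: self.mode == 'a'

    Multiple size mismatches found. The following environment variable would enable dynamic compilation to start, avoiding this recompile: TORCH_COMPILE_DYNAMIC_SOURCES="L['x'],L['self'].attr,L['self']._modules['lin']._parameters['weight'],L['self']._modules['lin']._parameters['bias']"
    Size guard failed on a parameter, consider using torch._dynamo.config.force_parameter_static_shapes = False to allow dynamism on parameters.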