[WIP] suggest whitelist for dynamic shape recompilations #153442

Draft
pianpwk wants to merge 10 commits into main
Conversation

@pianpwk (Contributor) commented on May 13, 2025:

Adds more processing of recompilation reasons to detect tensor sources that recompiled because of dynamic shape changes, and suggests a dynamic-source whitelist to reduce recompilations.

Refactors GuardDebugInfo to hold separate verbose_code_parts and failure_reasons fields: the former contains eval-able code, the latter plain-English reasons. Previously both were combined in a single field, which made it hard to decide what to eval or pattern-match.
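
A minimal sketch of the split (field names come from the description above; the real GuardDebugInfo in torch/_dynamo/guards.py carries more state, and this is not its actual definition):

import dataclasses

@dataclasses.dataclass
class GuardDebugInfo:
    # Sketch only: illustrates the two separated fields described above.
    result: bool
    # eval-able code parts, e.g. "self.mode == 'a'"
    verbose_code_parts: list[str]
    # plain-English reasons, e.g. "tensor 'x' size mismatch at index 0. expected 7, actual 9"
    failure_reasons: list[str]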

For this toy example:

import torch

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)
        self.attr = torch.randn(4)
        self.mode = "a"
    def forward(self, x, y):
        if self.mode == "a":
            return self.lin(x * 2) + self.attr + y
        else:
            return self.lin(x / 2) - self.attr - y

# 1: first call compiles forward with static shapes and mode "a"
m = Foo()
fn = torch.compile(m)
fn(torch.randn(7, 4), torch.randn(4))

# 2: changing mode flips the attribute guard and triggers the first recompile
fn.mode = "b"
fn(torch.randn(7, 4), torch.randn(4))

# 3: new module/attribute/input sizes trigger the second recompile with multiple size mismatches
fn.lin = torch.nn.Linear(8, 5)
fn.attr = torch.randn(5)
fn.mode = "b"
fn(torch.randn(9, 8), torch.randn(5))
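
The logs below were produced with the TORCH_LOGS environment variable; the same artifacts can also be enabled programmatically (standard torch._logging usage, independent of this PR):

# Programmatic equivalent of TORCH_LOGS="recompiles" / TORCH_LOGS="recompiles_verbose".
import torch._logging

torch._logging.set_logs(recompiles=True)
# torch._logging.set_logs(recompiles_verbose=True)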

logs with TORCH_LOGS="recompiles":
(first recompile)

V0512 19:49:21.987000 2692736 torch/_dynamo/guards.py:3451] [0/1] [__recompiles] Recompiling function forward in /data/users/pianpwk/pytorch/custom_tests/test_recompiles_tlparse.py:14
V0512 19:49:21.987000 2692736 torch/_dynamo/guards.py:3451] [0/1] [__recompiles]     triggered by the following guard failure(s):
V0512 19:49:21.987000 2692736 torch/_dynamo/guards.py:3451] [0/1] [__recompiles]     - 0/0: self.mode == 'a'                                       

(second recompile)

V0512 19:49:22.098000 2692736 torch/_dynamo/guards.py:3451] [0/2] [__recompiles] Recompiling function forward in /data/users/pianpwk/pytorch/custom_tests/test_recompiles_tlparse.py:14
V0512 19:49:22.098000 2692736 torch/_dynamo/guards.py:3451] [0/2] [__recompiles]     triggered by the following guard failure(s):
V0512 19:49:22.098000 2692736 torch/_dynamo/guards.py:3451] [0/2] [__recompiles]     - 0/1: tensor 'x' size mismatch at index 0. expected 7, actual 9
V0512 19:49:22.098000 2692736 torch/_dynamo/guards.py:3451] [0/2] [__recompiles]     - 0/0: self.mode == 'a'                                       
V0512 19:49:22.098000 2692736 torch/_dynamo/guards.py:3451] [0/2] [__recompiles] 
V0512 19:49:22.098000 2692736 torch/_dynamo/guards.py:3451] [0/2] [__recompiles]     Multiple size mismatches found. The following environment variable would enable dynamic compilation to start, avoiding this recompile: TORCH_COMPILE_DYNAMIC_SOURCES="L['x'],L['y'],L['self'].attr,L['self']._modules['lin']._parameters['bias'],L['self']._modules['lin']._parameters['weight']"
V0512 19:49:22.098000 2692736 torch/_dynamo/guards.py:3451] [0/2] [__recompiles]     Size guard failed on a parameter, consider using torch._dynamo.config.force_parameter_static_shapes = False to allow dynamism on parameters.

logs with TORCH_LOGS="recompiles_verbose":
(first recompile)

V0512 19:51:39.792000 2709017 torch/_dynamo/guards.py:3449] [0/1] [__recompiles_verbose] Recompiling function forward in /data/users/pianpwk/pytorch/custom_tests/test_recompiles_tlparse.py:14
V0512 19:51:39.792000 2709017 torch/_dynamo/guards.py:3449] [0/1] [__recompiles_verbose]     triggered by the following guard failure(s):
V0512 19:51:39.792000 2709017 torch/_dynamo/guards.py:3449] [0/1] [__recompiles_verbose]     guard 0 failures:
V0512 19:51:39.792000 2709017 torch/_dynamo/guards.py:3449] [0/1] [__recompiles_verbose]     - 0/0: self.mode == 'a'

(second recompile)

V0512 19:51:39.865000 2709017 torch/_dynamo/guards.py:3449] [0/2] [__recompiles_verbose] Recompiling function forward in /data/users/pianpwk/pytorch/custom_tests/test_recompiles_tlparse.py:14
V0512 19:51:39.865000 2709017 torch/_dynamo/guards.py:3449] [0/2] [__recompiles_verbose]     triggered by the following guard failure(s):
V0512 19:51:39.865000 2709017 torch/_dynamo/guards.py:3449] [0/2] [__recompiles_verbose]     guard 0 failures:
V0512 19:51:39.865000 2709017 torch/_dynamo/guards.py:3449] [0/2] [__recompiles_verbose]     - 0/1: tensor 'x' size mismatch at index 0. expected 7, actual 9; tensor 'y' size mismatch at index 0. expected 4, actual 5; tensor 'self.attr' size mismatch at index 0. expected 4, actual 5; tensor 'self._modules['lin']._parameters['bias']' size mismatch at index 0. expected 4, actual 5; tensor 'self._modules['lin']._parameters['weight']' size mismatch at index 0. expected 4, actual 5
V0512 19:51:39.865000 2709017 torch/_dynamo/guards.py:3449] [0/2] [__recompiles_verbose] 
V0512 19:51:39.865000 2709017 torch/_dynamo/guards.py:3449] [0/2] [__recompiles_verbose]     guard 1 failures:
V0512 19:51:39.865000 2709017 torch/_dynamo/guards.py:3449] [0/2] [__recompiles_verbose]     - 0/0: self.mode == 'a'                                       ; tensor 'self.attr' size mismatch at index 0. expected 4, actual 5; tensor 'self._modules['lin']._parameters['bias']' size mismatch at index 0. expected 4, actual 5; tensor 'self._modules['lin']._parameters['weight']' size mismatch at index 0. expected 4, actual 5; tensor 'x' size mismatch at index 0. expected 7, actual 9; tensor 'y' size mismatch at index 0. expected 4, actual 5
V0512 19:51:39.865000 2709017 torch/_dynamo/guards.py:3449] [0/2] [__recompiles_verbose] 
V0512 19:51:39.865000 2709017 torch/_dynamo/guards.py:3449] [0/2] [__recompiles_verbose]     Multiple size mismatches found. The following environment variable would enable dynamic compilation to start, avoiding this recompile: TORCH_COMPILE_DYNAMIC_SOURCES="L['x'],L['y'],L['self'].attr,L['self']._modules['lin']._parameters['bias'],L['self']._modules['lin']._parameters['weight']"
V0512 19:51:39.865000 2709017 torch/_dynamo/guards.py:3449] [0/2] [__recompiles_verbose]     Size guard failed on a parameter, consider using torch._dynamo.config.force_parameter_static_shapes = False to allow dynamism on parameters.
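
One way to act on the suggestion (a sketch: the env var and the force_parameter_static_shapes flag are the ones named in the logs above; the script name is hypothetical):

# Sketch: apply the suggested dynamic-source whitelist from the logs above.
# Env-backed configs may be read when torch is imported, so exporting the variable
# in the shell before launching Python is the safest route, e.g.:
#   TORCH_COMPILE_DYNAMIC_SOURCES="L['x'],L['y'],..." python your_script.py   # your_script.py is hypothetical
# Setting it from Python before importing torch should be equivalent:
import os

os.environ["TORCH_COMPILE_DYNAMIC_SOURCES"] = (
    "L['x'],L['y'],L['self'].attr,"
    "L['self']._modules['lin']._parameters['bias'],"
    "L['self']._modules['lin']._parameters['weight']"
)

import torch

# The logs also note the size guard failed on parameters, which are static by default.
torch._dynamo.config.force_parameter_static_shapes = False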

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames

@pytorch-bot (bot) commented on May 13, 2025:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153442

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 6 New Failures

As of commit a11c15e with merge base 3aa8477:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pianpwk changed the title from "initw" to "[WIP] suggest whitelist for dynamic shape recompilations" on May 13, 2025
"Replicate": Replicate,
"Partial": Partial,
"DeviceMesh": DeviceMesh,
})
@pianpwk (Contributor, Author) commented on May 13, 2025:
eval was failing due to local imports here:

if torch.distributed.is_available():
    from torch.distributed.device_mesh import DeviceMesh
    from torch.distributed.tensor.placement_types import (
        Partial,
        Replicate,
        Shard,
    )

    ok_types = ok_types + (
        Shard,
        Replicate,
        Partial,
        DeviceMesh,
    )
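
For context, guard code parts are eval'd against a fixed namespace, so a name that only exists as a local import inside the enclosing function raises NameError during eval. A toy illustration (plain Python, not the actual guards.py eval path):

# Toy illustration: eval of a code part fails when the referenced type
# is only available as a local import.
def make_value():
    from collections import OrderedDict  # local import, invisible to the eval below
    return OrderedDict()

code_part = "isinstance(x, OrderedDict)"
try:
    eval(code_part, {}, {"x": make_value()})
except NameError as err:
    print(err)  # name 'OrderedDict' is not defined

# Fix direction: make the name available in the namespace handed to eval.
from collections import OrderedDict
print(eval(code_part, {"OrderedDict": OrderedDict}, {"x": make_value()}))  # True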

code_part = f"{ref}.__tensor_flatten__()[1] == {original_metadata}"
self.get_guard_manager(guard).add_lambda_guard(
    metadata_checker, get_verbose_code_parts(code_part, guard)
)
@pianpwk (Contributor, Author) commented:

__check_metadata wasn't registered as a callable in closure_vars, so the guard's code part wasn't eval-able.
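
A similar toy sketch for this one (the helper name mirrors the comment above; the real registration lives in the closure_vars namespace in guards.py):

# Toy sketch: a code part that calls a helper is only eval-able if the helper
# is registered in the namespace passed to eval (the analogue of closure_vars).
def __check_metadata(checker, tensor):
    return checker(tensor)

closure_vars = {"__check_metadata": __check_metadata}
code_part = "__check_metadata(lambda t: t is not None, x)"
print(eval(code_part, closure_vars, {"x": object()}))  # True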
