Decorators like `torch.compiler.allow_in_graph` don't account for id reuse · Issue #147777 · pytorch/pytorch

Closed
StrongerXi opened this issue Feb 24, 2025 · 0 comments

Assignees: StrongerXi
Labels: module: dynamo, oncall: pt2, triaged

Comments

StrongerXi (Contributor) commented Feb 24, 2025

🐛 Describe the bug

Context: https://github.com/pytorch/pytorch/pull/146367/files#r1964644166

Repro:

import torch

@torch.compiler.allow_in_graph
def f(x):
    return x + 1
del f  # f's id can now be reused by a later function object

def g(x):
    return x + 2

@torch.compile(fullgraph=True, backend="eager")
def fn(x):
    return g(x)

fn(torch.ones(1))

Run it with TORCH_LOGS="graph_code":

output_graph.py:1385] [0/0] [__graph_code] TRACED GRAPH
output_graph.py:1385] [0/0] [__graph_code]  ===== __compiled_fn_1 =====
output_graph.py:1385] [0/0] [__graph_code]  /Users/ryanguo99/Documents/work/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
output_graph.py:1385] [0/0] [__graph_code]     def forward(self, L_x_: "f32[1][1]cpu"):
output_graph.py:1385] [0/0] [__graph_code]         l_x_ = L_x_
output_graph.py:1385] [0/0] [__graph_code]
output_graph.py:1385] [0/0] [__graph_code]          # File: /Users/ryanguo99/Documents/work/scratch/allow-in-graph.py:14 in fn, code: return g(x)
output_graph.py:1385] [0/0] [__graph_code]         g: "f32[1][1]cpu" = __main___g(l_x_);  l_x_ = None
output_graph.py:1385] [0/0] [__graph_code]         return (g,)

Note that in the first graph, g was kept as an opaque call (__main___g) rather than being traced into: Dynamo treated it as allowed in graph even though only f was decorated, because g reused f's id. Commenting out del f and rerunning:

output_graph.py:1385] [0/0] [__graph_code] TRACED GRAPH
output_graph.py:1385] [0/0] [__graph_code]  ===== __compiled_fn_1 =====
output_graph.py:1385] [0/0] [__graph_code]  /Users/ryanguo99/Documents/work/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
output_graph.py:1385] [0/0] [__graph_code]     def forward(self, L_x_: "f32[1][1]cpu"):
output_graph.py:1385] [0/0] [__graph_code]         l_x_ = L_x_
output_graph.py:1385] [0/0] [__graph_code]
output_graph.py:1385] [0/0] [__graph_code]          # File: /Users/ryanguo99/Documents/work/scratch/allow-in-graph.py:9 in g, code: return x + 2
output_graph.py:1385] [0/0] [__graph_code]         add: "f32[1][1]cpu" = l_x_ + 2;  l_x_ = None
output_graph.py:1385] [0/0] [__graph_code]         return (add,)
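
The root cause: `allow_in_graph` keys its registration on the function's `id()`. After `del f`, CPython is free to reuse f's memory address, so the later function object g can end up with the same id and be misclassified as allowed in graph. A minimal, PyTorch-free sketch of the mechanism (whether the id is actually reused is allocator-dependent, so the last line may print False on some runs):

def f(x):
    return x + 1

old_id = id(f)
del f  # the only reference is gone; in CPython, f is deallocated immediately

def g(x):
    return x + 2

# If g was allocated at the address f used to occupy, a registry keyed
# by id(f) would now wrongly match g.
print(id(g) == old_id)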

Error logs

No response

Versions

main

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames

@StrongerXi StrongerXi self-assigned this Feb 24, 2025
@desertfire desertfire added the triaged label Feb 25, 2025
pytorchmergebot pushed a commit that referenced this issue Mar 5, 2025
This fixes a recent series of flaky failures from the `nonstrict_trace` unit
tests: #148166, #148056, #148055, #148054, #148034, #148033, #148032, #148031.

For now we don't need to worry about the other decorators, because they
are either meant for builtin/numpy functions (which should never be
deallocated in practice) or used for polyfills, which keep the function
object alive in `get_torch_obj_rule_map()`.

Fixes #147777.

ghstack-source-id: d9bea5f
Pull Request resolved: #148385
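
For reference, a hypothetical sketch of an id-reuse-safe registry, using a weakref callback to evict an entry as soon as the registered function is collected. The names `_allowed`, `allow_in_graph`, and `is_allowed` are illustrative only, not the actual implementation in #148385:

import weakref

# Hypothetical registry; not PyTorch code.
_allowed: dict[int, weakref.ref] = {}

def allow_in_graph(fn):
    key = id(fn)
    # Evict the entry when fn is garbage-collected, so a future object
    # that happens to reuse the same id is not misclassified.
    _allowed[key] = weakref.ref(fn, lambda _ref, key=key: _allowed.pop(key, None))
    return fn

def is_allowed(fn) -> bool:
    ref = _allowed.get(id(fn))
    return ref is not None and ref() is fn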