8000 [dynamo] Account for function id reuse in relevant Dynamo decorators by StrongerXi · Pull Request #148385 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[dynamo] Account for function id reuse in relevant Dynamo decorators #148385

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from

Conversation

StrongerXi
Copy link
Contributor
@StrongerXi StrongerXi commented Mar 4, 2025

Stack from ghstack (oldest at bottom):

This fixes a recent series of flaky failure from nonstrict_trace unit
tests: #148166, #148056, #148055, #148054, #148034, #148033, #148032, #148031.

For now we don't need to worry about the other decorators because they
are either meant for builtin/numpy functions (which should never
deallocate in practice), or used for polyfills which keeps the function
object in get_torch_obj_rule_map().

Fixes #147777.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames

[ghstack-poisoned]
Copy link
pytorch-bot bot commented Mar 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148385

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f6ee148 with merge base aade4fb (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Comment on lines 165 to 166
trace_rules._allowed_callable_ids.remove(fn_id)
trace_rules._nonstrict_trace_callable_ids.remove(fn_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to remove stuff from _disallowed_callable_ids?

Copy link
Contributor Author
@StrongerXi StrongerXi Mar 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OOps copy paste error.

Removed trace_rules._nonstrict_trace_callable_ids.remove(fn_id), and we don't need to touch _disallowed_callable_ids because allow_in_graph was removing fn_id from it, so it's safe.

Copy link
Contributor
@zou3519 zou3519 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

[ghstack-poisoned]
@pytorchmergebot
Copy link
Collaborator

Starting merge as part of PR stack under #148007

pytorchmergebot pushed a commit that referenced this pull request Mar 5, 2025
…ant`-ed instances (#148007)

As title, this enables `nonstrict_trace`-ed function to take in object
whose type has been `pytree.register_constant`-ed, as long as the object
existed outside the `torch.compile` region. This also forces Dynamo to
emit a `EQUALS_MATCH` guard on the object.

Pull Request resolved: #148007
Approved by: https://github.com/zou3519
ghstack dependencies: #148385
pytorchmergebot pushed a commit that referenced this pull request Mar 5, 2025
This fixes a recent series of flaky failure from `nonstrict_trace` unit
tests: #148166, #148056, #148055, #148054, #148034, #148033, #148032, #148031.

For now we don't need to worry about the other decorators because they
are either meant for builtin/numpy functions (which should never
deallocate in practice), or used for polyfills which keeps the function
object in `get_torch_obj_rule_map()`.

Fixes #147777.

ghstack-source-id: d9bea5f
Pull Request resolved: #148385
@github-actions github-actions bot deleted the gh/StrongerXi/90/head branch April 11, 2025 02:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Decorators like torch.compiler.allow_in_graph doesn't account for id reuse
3 participants
0