[dynamo] Emit warning on global module hooks when calling using output of `torch.compile(module)` by StrongerXi · Pull Request #152740 · pytorch/pytorch

Closed

StrongerXi wants to merge 5 commits

Conversation

@StrongerXi StrongerXi (Contributor) commented May 2, 2025

Stack from ghstack (oldest at bottom):

When we do `torch.compile(module)`, we eventually end up returning a new
`OptimizedModule` instance, whose `forward` method is the result of
`torch.compile(mod.__call__)`, meaning it already captures all the extra
logic (e.g., hook firing) for the compiled module.

`OptimizedModule` also inherits `nn.Module.__call__`, and thus
has its own hook logic. This is useful for torchao, which injects module
forward hooks to run in eager mode for quantization purposes.

However, this can create unexpected behavior for global module hooks,
because `torch.compile(module)` causes each hook to fire one extra time
(for the `OptimizedModule`) compared to eager.

To preserve backward compatibility, we simply emit a warning for this
behavior and let users decide what to do. This is reasonable because the
global module hooks are documented as being for debugging/profiling
purposes only.
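A minimal repro sketch of the double firing (the module, hook, and input below are illustrative, not from the PR):

```python
import torch
import torch.nn as nn

class Mod(nn.Module):
    def forward(self, x):
        return x + 1

seen = []
# A global pre-hook fires for every nn.Module __call__, including the
# OptimizedModule wrapper returned by torch.compile(module).
handle = nn.modules.module.register_module_forward_pre_hook(
    lambda mod, args: seen.append(type(mod).__name__)
)
try:
    Mod()(torch.ones(3))                 # eager: fires once, for Mod
    torch.compile(Mod())(torch.ones(3))  # fires for OptimizedModule and Mod
finally:
    handle.remove()
print(seen)  # expect one extra "OptimizedModule" entry in the compiled run
```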

Fixes #149502

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames

Differential Revision: D74611716

…compile(mod)`

When we do `torch.compile(mod)`, we eventually end up returning a new
module instance, whose `forward` method is the result of
`torch.compile(mod.__call__)`, meaning it already captures all the extra
logic (e.g., hook firing) from the default `torch.nn.Module.__call__`.
As a result we can't reuse the inherited default `__call__` as is,
because we'd end up running the logic twice.

This patch makes the returned `OptimizedModule` override the default
`__call__`, and directly calls into its compiled `forward` method.
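Roughly, the override looks like this (a sketch of the idea, not the actual `torch._dynamo` source):

```python
class OptimizedModule(torch.nn.Module):
    def __init__(self, mod, dynamo_ctx):
        super().__init__()
        self._orig_mod = mod
        # forward is the compiled original __call__, so hook firing is
        # already baked into it.
        self.forward = dynamo_ctx(mod.__call__)

    def __call__(self, *args, **kwargs):
        # Bypass the inherited nn.Module.__call__ hook machinery so the
        # hook logic doesn't run a second time.
        return self.forward(*args, **kwargs)
```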

Fixes #149502

[ghstack-poisoned]
pytorch-bot bot commented May 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152740

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit f8c4dc1 with merge base 7243c69:

BROKEN TRUNK - The following job failed but was also present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@anijain2305 anijain2305 (Contributor) left a comment

Nice!

@pytorchmergebot (Collaborator)

Starting merge as part of PR stack under #152741

pytorchmergebot pushed a commit that referenced this pull request May 6, 2025

This is essentially a follow-up to #122098, where we added support for
`getattr` and `setattr` on the result of `torch.compile(module)`, but didn't
add support for `delattr`.

Fixes #150711.

Pull Request resolved: #152741
Approved by: https://github.com/anijain2305
ghstack dependencies: #152740
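For illustration, the attribute forwarding the stacked PR completes (the module and attribute names below are made up):

```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        return x * 2

opt = torch.compile(M())
opt.scale = 10   # setattr forwards to the wrapped module (#122098)
assert opt._orig_mod.scale == 10
del opt.scale    # delattr now forwards as well (#152741)
assert not hasattr(opt._orig_mod, "scale")
```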
@huydhn huydhn (Contributor) commented May 8, 2025

@pytorchbot revert -m 'Discuss with the author to revert and reland this' -c ghfirst

@pytorchmergebot (Collaborator)

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request May 8, 2025
…)` (#152741)"

This reverts commit 6c025b5.

Reverted #152741 on behalf of https://github.com/huydhn due to Discuss with the author to revert and reland this ([comment](#152740 (comment)))
pytorchmergebot added a commit that referenced this pull request May 8, 2025
… `torch.compile(mod)` (#152740)"

This reverts commit 0886d40.

Reverted #152740 on behalf of https://github.com/huydhn due to Discuss with the author to revert and reland this ([comment](#152740 (comment)))
@pytorchmergebot (Collaborator)

@StrongerXi your PR has been successfully reverted.

@pytorchmergebot pytorchmergebot added the Reverted and ci-no-td labels May 8, 2025
@StrongerXi StrongerXi changed the title from "[dynamo] Avoid running torch.nn.Module.__call__ twice under torch.compile(mod)" to "[dynamo] Emit warning on global module hooks when calling using output of torch.compile(module)" May 9, 2025
@StrongerXi StrongerXi requested review from anijain2305 and zou3519 May 9, 2025 22:26
@@ -118,6 +118,21 @@ def __setstate__(self, state: dict):
_global_forward_hooks_always_called: dict[int, bool] = OrderedDict()
_global_forward_hooks_with_kwargs: dict[int, bool] = OrderedDict()


def _has_any_global_hook():
Collaborator

Not sure what you want from this, but if you want minimal overhead, you want boolean conversions here:

  return bool(_global_backward_pre_hooks or _global_backward_hooks ...)

Contributor Author

I always found the implicit `bool(container_object)` a little too implicit, but I think that's just me not being Pythonic enough. Updated :)

Collaborator

It is a bit confusing sometimes, I agree. But we've seen in other cases that it is a few orders of magnitude faster... so I guess it's worth it.

Contributor

Perf is worth it, given that these checks are called in the hot path.

Contributor

(It's not clear to me if you resolved this already or not.)

Contributor Author

Already updated :)
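For reference, the truthiness-based version discussed above looks roughly like this (the exact set of global hook dicts checked is an assumption based on the diff context):

```python
def _has_any_global_hook():
    # Empty dicts are falsy, so this short-circuits on the first
    # non-empty hook dict without any len() calls.
    return bool(
        _global_backward_pre_hooks
        or _global_backward_hooks
        or _global_forward_pre_hooks
        or _global_forward_hooks
    )
```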

@StrongerXi StrongerXi (Contributor Author)

*kwargs --> **kwargs

@StrongerXi StrongerXi (Contributor Author)

@StrongerXi has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@pytorch-bot pytorch-bot bot added the ciflow/trunk label May 12, 2025
"`OptimizedModule` created by `torch.compile(module)`. If this "
"causes undesired behavior, please try using `module.compile()`"
", or use the per-module hooks instead",
)
Contributor

Could you figure out the right int to pass as stacklevel?

Contributor Author

Updated.
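For context, `stacklevel` controls which frame the warning is attributed to; the goal is to point at the user's call site rather than framework internals. A toy illustration (not the PR's code):

```python
import warnings

def _inner():
    # stacklevel=3 walks up past _inner and _outer, attributing the
    # warning to whoever called _outer.
    warnings.warn("global module hook fired on OptimizedModule", stacklevel=3)

def _outer():
    _inner()

_outer()  # the reported file/line is this call site
```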

return args

mod = Mod()
torch.nn.modules.module.register_module_forward_pre_hook(hook)
Contributor

Does this pollute the state of all the other tests? Can we do a try-finally where we deregister the hook in the finally?

Contributor Author

That is a good call lol, updated.
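The suggested pattern, sketched with the test's names from the diff above (the input shape is made up):

```python
handle = torch.nn.modules.module.register_module_forward_pre_hook(hook)
try:
    opt_mod = torch.compile(mod)
    opt_mod(torch.ones(4))
finally:
    # Deregister so the global hook doesn't leak into other tests.
    handle.remove()
```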

@StrongerXi StrongerXi (Contributor Author)

@StrongerXi has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@pytorchmergebot (Collaborator)

Starting merge as part of PR stack under #152741

pytorchmergebot pushed a commit that referenced this pull request May 14, 2025

This is essentially a follow-up to #122098, where we added support for
`getattr` and `setattr` on the result of `torch.compile(module)`, but didn't
add support for `delattr`.

Fixes #150711.

Pull Request resolved: #152741
Approved by: https://github.com/anijain2305
ghstack dependencies: #152740
Labels
ci-no-td · ciflow/inductor · ciflow/trunk · Merged · module: dynamo · Reverted · topic: not user facing