[dynamo] Emit warning on global module hooks when calling using output of `torch.compile(module)` (#152740)
Conversation
…compile(mod)` When we do `torch.compile(mod)`, we eventually end up returning a new module instance, whose `forward` method is the result of `torch.compile(mod.__call__)`, meaning it already captures all the extra logic (e.g., hook firing) from the default `torch.nn.Module.__call__`. As a result we can't reuse the inherited default `__call__` as is, because we'd end up running the logic twice. This patch makes the returned `OptimizedModule` override the default `__call__`, and directly calls into its compiled `forward` method. Fixes #149502 [ghstack-poisoned]
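For context, a minimal sketch of the double-firing this patch addresses (the hook body and the `Linear` module here are illustrative, not taken from the PR):

```python
import torch

calls = []

def hook(mod, args):
    # Record which module type the global pre-hook fired for.
    calls.append(type(mod).__name__)
    return args

handle = torch.nn.modules.module.register_module_forward_pre_hook(hook)
try:
    mod = torch.nn.Linear(2, 2)
    opt_mod = torch.compile(mod)
    opt_mod(torch.randn(1, 2))
    # Without the fix, the hook fires for both OptimizedModule and the
    # wrapped Linear, i.e. twice per call instead of once as in eager.
    print(calls)
finally:
    handle.remove()
```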
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152740

Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 unrelated failure.) As of commit f8c4dc1 with merge base 7243c69. BROKEN TRUNK: the following job failed but was also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Nice!
Starting merge as part of PR stack under #152741
This is essentially a follow-up on #122098, where we added support for `getattr` and `setattr` on the result of `torch.compile(module)`, but didn't add support for `delattr`. Fixes #150711. Pull Request resolved: #152741 Approved by: https://github.com/anijain2305 ghstack dependencies: #152740
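A short sketch of the `delattr` behavior #152741 adds, alongside the `getattr`/`setattr` forwarding from #122098 (the attribute name `some_flag` is illustrative):

```python
import torch

mod = torch.nn.Linear(2, 2)
opt_mod = torch.compile(mod)

opt_mod.some_flag = True  # setattr forwards to the wrapped module
assert mod.some_flag is True

del opt_mod.some_flag     # delattr now forwards as well
assert not hasattr(mod, "some_flag")
```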
@pytorchbot revert -m 'Discuss with the author to revert and reland this' -c ghfirst

@pytorchbot successfully started a revert job. Check the current status here.
…)` (#152741)" This reverts commit 6c025b5. Reverted #152741 on behalf of https://github.com/huydhn due to Discuss with the author to revert and reland this (comment on #152740)
… `torch.compile(mod)` (#152740)" This reverts commit 0886d40. Reverted #152740 on behalf of https://github.com/huydhn due to Discuss with the author to revert and reland this (comment on #152740)
@StrongerXi your PR has been successfully reverted.
…der `torch.compile(mod)`" When we do `torch.compile(mod)`, we eventually end up returning a new module instance, whose `forward` method is the result of `torch.compile(mod.__call__)`, meaning it already captures all the extra logic (e.g., hook firing) from the default `torch.nn.Module.__call__`. As a result we can't reuse the inherited default `__call__` as is, because we'd end up running the logic twice. This patch makes the returned `OptimizedModule` override the default `__call__`, and directly calls into its compiled `forward` method. Fixes #149502 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
The PR title was changed from one about `torch.nn.Module.__call__` running twice under `torch.compile(mod)` to "[dynamo] Emit warning on global module hooks when calling using output of `torch.compile(module)`".
@@ -118,6 +118,21 @@ def __setstate__(self, state: dict):
 _global_forward_hooks_always_called: dict[int, bool] = OrderedDict()
 _global_forward_hooks_with_kwargs: dict[int, bool] = OrderedDict()

+def _has_any_global_hook():
Not sure why you want this, but if you want minimal overhead, you want boolean conversions here:
`return not (_global_backward_pre_hooks or _global_backward_hooks ...)`
I always found implicit `bool(container_object)` a little too implicit, but I think that's just me not being Pythonic enough. Updated :).
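For reference, a sketch of what the updated helper plausibly looks like after this exchange; the registry names mirror the module-level hook dicts in `torch/nn/modules/module.py`, but the exact set checked here is an assumption:

```python
def _has_any_global_hook() -> bool:
    # Truthiness checks on the global hook registries avoid building any
    # intermediate objects, which is what makes this cheap in the hot path.
    return bool(
        _global_backward_pre_hooks
        or _global_backward_hooks
        or _global_forward_pre_hooks
        or _global_forward_hooks
    )
```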
It is a bit confusing sometimes, I agree. But we saw in other cases that this is a few orders of magnitude faster, so I guess it's worth it.
The perf is worth it, given that these checks are called in the hot path.
(it's not clear to me if you resolved this already or not)
Already updated :).
…using output of `torch.compile(module)`" When we do `torch.compile(module)`, we eventually end up returning a new `OptimizedModule` instance, whose `forward` method is the result of `torch.compile(mod.__call__)`, meaning it already captures all the extra logic (e.g., hook firing) for the compiled module. `OptimizedModule` also inherits `nn.Module.__call__`, and thus has its own hook logic. This is useful for torchao, which injects module forward hooks to run in eager for quantization purposes. However, this might create unexpected behavior for global module hooks, because `torch.compile(module)` causes the hook to fire one extra time for `OptimizedModule`, when compared to eager. To preserve BC, we simply emit a warning for this behavior, and let users decide what to do. This is reasonable because the global module hooks are documented to be used for debugging/profiling purposes only. Fixes #149502 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
@StrongerXi has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
"`OptimizedModule` created by `torch.compile(module)`. If this " | ||
"causes undesired behavior, please try using `module.compile()`" | ||
", or use the per-module hooks instead", | ||
) |
Could you figure out the right int to use as `stacklevel`?
Updated
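A sketch of how the warning could be emitted with an explicit `stacklevel`; only the tail of the message is visible in the diff above, so the leading text and the value `stacklevel=2` are assumptions:

```python
import warnings

warnings.warn(
    # Leading text paraphrased from the PR description; only the tail
    # below appears verbatim in the diff.
    "Global module hooks will also fire on the extra "
    "`OptimizedModule` created by `torch.compile(module)`. If this "
    "causes undesired behavior, please try using `module.compile()`"
    ", or use the per-module hooks instead",
    stacklevel=2,  # assumed: point the warning at the user's call site
)
```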
test/dynamo/test_hooks.py (outdated)

    return args

mod = Mod()
torch.nn.modules.module.register_module_forward_pre_hook(hook)
Does this pollute the state of all the other tests? Can we do a try-finally where we deregister the hook in the finally?
That is a good call lol, updated.
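A sketch of the suggested try/finally cleanup (`Mod` and `hook` refer to the test snippet above; the rest of the body is illustrative):

```python
handle = torch.nn.modules.module.register_module_forward_pre_hook(hook)
try:
    mod = Mod()
    opt_mod = torch.compile(mod)
    opt_mod(torch.randn(2, 2))
finally:
    # Deregister the global hook so it cannot leak into other tests.
    handle.remove()
```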
…using output of `torch.compile(module)`" When we do `torch.compile(module)`, we eventually end up returning a new `OptimizedModule` instance, whose `forward` method is the result of `torch.compile(mod.__call__)`, meaning it already captures all the extra logic (e.g., hook firing) for the compiled module. `OptimizedModule` also inherits `nn.Module.__call__`, and thus has its own hook logic. This is useful for torchao, which injects module forward hooks to run in eager for quantization purposes. However, this might create unexpected behavior for global module hooks, because `torch.compile(module)` causes the hook to fire one extra time for `OptimizedModule`, when compared to eager. To preserve BC, we simply emit a warning for this behavior, and let users decide what to do. This is reasonable because the global module hooks are documented to be used for debugging/profiling purposes only. Fixes #149502 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Differential Revision: [D74611716](https://our.internmc.facebook.com/intern/diff/D74611716) [ghstack-poisoned]
@StrongerXi has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Starting merge as part of PR stack under #152741
This is essentially a follow-up on #122098, where we added support for `getattr` and `setattr` on the result of `torch.compile(module)`, but didn't add support for `delattr`. Fixes #150711. Pull Request resolved: #152741 Approved by: https://github.com/anijain2305 ghstack dependencies: #152740
Stack from ghstack (oldest at bottom):

- #152741 `delattr` on result of `torch.compile(module)`
- -> #152740 [dynamo] Emit warning on global module hooks when calling using output of `torch.compile(module)` (this PR)

When we do `torch.compile(module)`, we eventually end up returning a new `OptimizedModule` instance, whose `forward` method is the result of `torch.compile(mod.__call__)`, meaning it already captures all the extra logic (e.g., hook firing) for the compiled module. `OptimizedModule` also inherits `nn.Module.__call__`, and thus has its own hook logic. This is useful for torchao, which injects module forward hooks to run in eager for quantization purposes.

However, this might create unexpected behavior for global module hooks, because `torch.compile(module)` causes the hook to fire one extra time for `OptimizedModule`, when compared to eager.

To preserve BC, we simply emit a warning for this behavior and let users decide what to do. This is reasonable because the global module hooks are documented to be used for debugging/profiling purposes only.
Fixes #149502
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames
Differential Revision: D74611716
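For illustration, a minimal sketch contrasting the two entry points mentioned above, per the behavior this PR describes:

```python
import torch

mod = torch.nn.Linear(2, 2)

opt = torch.compile(mod)  # returns a new OptimizedModule wrapping mod,
assert opt is not mod     # so global module hooks see one extra module

mod.compile()             # compiles mod in place: no wrapper module, so
                          # global hooks fire exactly as they do in eager
```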