8000 [dynamo] Emit warning on global module hooks when calling using output of `torch.compile(module)` by StrongerXi · Pull Request #152740 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[dynamo] Emit warning on global module hooks when calling using output of torch.compile(module) #152740

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

8000 Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions test/dynamo/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,6 +889,36 @@ def forward(self, x):
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 1)

def test_global_module_forward_pre_hook(self):
class Mod(torch.nn.Module):
def forward(self, x):
return x - 1

counter = 0

def hook(mod, args):
nonlocal counter
counter += 1
return args

x = torch.rand(18, 18)
mod = Mod()
compiled_mod = torch.compile(mod, backend="eager")

try:
hook_handle = torch.nn.modules.module.register_module_forward_pre_hook(hook)
ref = mod(x)
self.assertEqual(counter, 1)
with self.assertWarnsRegex(
UserWarning,
r"Using `torch.compile\(module\)` when there are global hooks.*",
):
res = compiled_mod(x)
self.assertEqual(counter, 3)
self.assertEqual(ref, res)
finally:
hook_handle.remove()


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
Expand Down
13 changes: 13 additions & 0 deletions torch/_dynamo/eval_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,19 @@ def _initialize(self):
self._forward = self.forward
self.forward = self._call_lazy_check

def __call__(self, *args, **kwargs):
if torch.nn.modules.module._has_any_global_hook():
warnings.warn(
"Using `torch.compile(module)` when there are global hooks on "
"modules (e.g., from `register_module_forward_hook`); this will"
" cause the hooks to fire an extra time for the "
"`OptimizedModule` created by `torch.compile(module)`. If this "
"causes undesired behavior, please try using `module.compile()`"
", or use the per-module hooks instead",
stacklevel=2,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you figure out the right int to as stacklevel?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

return super().__call__(*args, **kwargs)

def __reduce__(self):
return (self.__class__, (self._orig_mod, self.dynamo_ctx))

Expand Down
12 changes: 12 additions & 0 deletions torch/nn/modules/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,18 @@ def __setstate__(self, state: dict):
_global_forward_hooks_always_called: dict[int, bool] = OrderedDict()
_global_forward_hooks_with_kwargs: dict[int, bool] = OrderedDict()


def _has_any_global_hook():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why you want from this but if you want minimal overhead, you want boolean conversions here:

  return not(_global_backward_pre_hooks or _global_backward_hooks ...)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I always found implicit bool(container_object) a little too implicit, but I think that's just me being not pythonic enough, updated:).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a bit confusing sometimes I agree. But in this case, we saw in other cases that is a few orders of magnitude faster... So I guess worth it...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perf is worth it given that these checks are called in the hotpath

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(it's not clear to me if you resolved this already or not)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Already updated:).

return (
_global_backward_pre_hooks
or _global_backward_hooks
or _global_forward_pre_hooks
or _global_forward_hooks
or _global_forward_hooks_always_called
or _global_forward_hooks_with_kwargs
)


_EXTRA_STATE_KEY_SUFFIX = "_extra_state"


Expand Down
Loading
0