8000 Revert "[dynamo] Avoid running `torch.nn.Module.__call__` twice under… · pytorch/pytorch@d36261d · GitHub
[go: up one dir, main page]

Skip to content

Commit d36261d

Browse files
Revert "[dynamo] Avoid running torch.nn.Module.__call__ twice under torch.compile(mod) (#152740)"
This reverts commit 0886d40. Reverted #152740 on behalf of https://github.com/huydhn due to Discuss with the author to revert and reland this ([comment](#152740 (comment)))
1 parent 34d424d commit d36261d

File tree

2 files changed

+0
-32
lines changed

2 files changed

+0
-32
lines changed

test/dynamo/test_hooks.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -889,32 +889,6 @@ def forward(self, x):
889889
self.assertEqual(ref, res)
890890
self.assertEqual(cnts.frame_count, 1)
891891

892-
def test_global_module_forward_pre_hook(self):
893-
class Mod(torch.nn.Module):
894-
def forward(self, x):
895-
return x - 1
896-
897-
counter = 0
898-
899-
def hook(mod, args):
900-
nonlocal counter
901-
counter += 1
902-
return (args[0] + 100,)
903-
904-
mod = Mod()
905-
torch.nn.modules.module.register_module_forward_pre_hook(hook)
906-
907-
# Case 1: torch.compile(mod)
908-
compiled_mod = torch.compile(mod, backend="eager")
909-
910-
x = torch.rand(18, 18)
911-
912-
ref = mod(x)
913-
self.assertEqual(counter, 1)
914-
res = compiled_mod(x)
915-
self.assertEqual(counter, 2)
916-
self.assertEqual(ref, res)
917-
918892

919893
if __name__ == "__main__":
920894
from torch._dynamo.test_case import run_tests

torch/_dynamo/eval_frame.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -343,12 +343,6 @@ def _initialize(self):
343343
self._forward = self.forward
344344
self.forward = self._call_lazy_check
345345

346-
def __call__(self, *args, **kwargs):
347-
# All the logic in `torch.nn.Module.__call__` has been captured by
348-
# `self.forward = self.dynamo_ctx(self._orig_mod.__call__)`, so we
349-
# override here to avoid running that logic again by default.
350-
return self.forward(*args, **kwargs)
351-
352346
def __reduce__(self):
353347
return (self.__class__, (self._orig_mod, self.dynamo_ctx))
354348

0 commit comments

Comments
 (0)
0