File tree 2 files changed +0
-32
lines changed 2 files changed +0
-32
lines changed Original file line number Diff line number Diff line change @@ -889,32 +889,6 @@ def forward(self, x):
889
889
self .assertEqual (ref , res )
890
890
self .assertEqual (cnts .frame_count , 1 )
891
891
892
- def test_global_module_forward_pre_hook (self ):
893
- class Mod (torch .nn .Module ):
894
- def forward (self , x ):
895
- return x - 1
896
-
897
- counter = 0
898
-
899
- def hook (mod , args ):
900
- nonlocal counter
901
- counter += 1
902
- return (args [0 ] + 100 ,)
903
-
904
- mod = Mod ()
905
- torch .nn .modules .module .register_module_forward_pre_hook (hook )
906
-
907
- # Case 1: torch.compile(mod)
908
- compiled_mod = torch .compile (mod , backend = "eager" )
909
-
910
- x = torch .rand (18 , 18 )
911
-
912
- ref = mod (x )
913
- self .assertEqual (counter , 1 )
914
- res = compiled_mod (x )
915
- self .assertEqual (counter , 2 )
916
- self .assertEqual (ref , res )
917
-
918
892
919
893
if __name__ == "__main__" :
920
894
from torch ._dynamo .test_case import run_tests
Original file line number Diff line number Diff line change @@ -343,12 +343,6 @@ def _initialize(self):
343
343
self ._forward = self .forward
344
344
self .forward = self ._call_lazy_check
345
345
346
- def __call__ (self , * args , ** kwargs ):
347
- # All the logic in `torch.nn.Module.__call__` has been captured by
348
- # `self.forward = self.dynamo_ctx(self._orig_mod.__call__)`, so we
349
- # override here to avoid running that logic again by default.
350
- return self .forward (* args , ** kwargs )
351
-
352
346
def __reduce__ (self ):
353
347
return (self .__class__ , (self ._orig_mod , self .dynamo_ctx ))
354
348
You can’t perform that action at this time.
0 commit comments