Revert "[dynamo] Make `OptimizedModule` more robust in attribute reads and writes (#153637)" · pytorch/pytorch@c2dda47 · GitHub

Commit c2dda47

Revert "[dynamo] Make OptimizedModule more robust in attribute reads and writes (#153637)"
This reverts commit 2ce0b66. Reverted #153637 on behalf of https://github.com/malfet because it looks like it broke slow tests; see https://hud.pytorch.org/hud/pytorch/pytorch/cda572b053033abc57b3b3358a861cbc71a490b9/1?per_page=50&name_filter=&mergeEphemeralLF=true and the comment on #153637.
1 parent cda572b commit c2dda47

File tree

2 files changed: +6 -50 lines

test/dynamo/test_repros.py

-34 lines

@@ -5853,40 +5853,6 @@ def test_optimized_module_training(self):
         mod.eval()
         self.assertFalse(opt_mod.training)
 
-    def test_optimized_module_patched_init(self):
-        # A regression test for #138157; the pattern came from DeepSpeed.
-        class MyModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, x):
-                return x.mul(5.0)
-
-        def patch_init(init):
-            @functools.wraps(init)
-            def wrapper(module, *args, **kwargs):
-                if not hasattr(module, "_ds_child_entered"):
-                    # The child's __init__ was called; since all parents see the same object, they can now skip post_init.
-                    module._ds_child_entered = True
-                init(module, *args, **kwargs)
-
-            return wrapper
-
-        def patch_init_for_class(cls):
-            if "__init__" in cls.__dict__:
-                cls._old_init = cls.__init__
-                cls.__init__ = patch_init(cls.__init__)
-
-        patch_init_for_class(MyModule)
-        mod = MyModule()
-        opt_mod = torch.compile(mod)
-
-        x = torch.rand(10)
-        ref = mod(x)
-        res = opt_mod(x)
-
-        self.assertEqual(ref, res)
-
     def test_os_fspath(self):
         @torch.compile(backend="eager", fullgraph=True)
         def fn(x):
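
For convenience, here is the removed test restated as a standalone script. This is only a sketch lifted out of the unittest harness; the names (MyModule, patch_init, _ds_child_entered) come from the test itself, and whether the assertion holds depends on whether the change reverted here is present in your build.

# Standalone restatement of the removed regression test for #138157.
import functools

import torch


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.mul(5.0)


def patch_init(init):
    @functools.wraps(init)
    def wrapper(module, *args, **kwargs):
        if not hasattr(module, "_ds_child_entered"):
            # Mark the instance so parent classes, which see the same object,
            # know a child __init__ already ran and can skip their post_init.
            module._ds_child_entered = True
        init(module, *args, **kwargs)

    return wrapper


def patch_init_for_class(cls):
    # Only patch classes that define their own __init__.
    if "__init__" in cls.__dict__:
        cls._old_init = cls.__init__
        cls.__init__ = patch_init(cls.__init__)


patch_init_for_class(MyModule)

mod = MyModule()
opt_mod = torch.compile(mod)

x = torch.rand(10)
assert torch.allclose(mod(x), opt_mod(x))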

torch/_dynamo/eval_frame.py

+6 -16 lines

@@ -313,23 +313,12 @@ class OptimizedModule(torch.nn.Module):
         "_forward",
         "__dict__",
         "named_children_walk",
-        "_super_module_initialized",
     }
 
     def __init__(self, mod: torch.nn.Module, dynamo_ctx) -> None:
-        # NOTE: this must go first, because attribute reads/writes of `self`
-        # uses `_orig_mod`, and sometimes users override `Module.__init__` to
-        # do attribute reads/writes on `self`.
-        #
-        # We also can't use regular setattr because `super().__setattr__` will
-        # complain for module value before `super().__init__()`
-        object.__setattr__(self, "_orig_mod", mod)
-        self._super_module_initialized = False
         super().__init__()
-        self._super_module_initialized = True
-
         # Installs the params/buffer
-        self._orig_mod = mod  # `super().__setattr__` will register this module
+        self._orig_mod = mod
         self.dynamo_ctx = dynamo_ctx
         self._initialize()
         self.training = self._orig_mod.training
@@ -390,11 +379,12 @@ def training(self):
 
     @training.setter
     def training(self, value):
-        # Ignore the `training` mutation in `super().__init__()`, since that's
-        # setting the default on `nn.Module`, but we are mirroring the
-        # `training` attr in `self._orig_mod`.
-        if self._super_module_initialized:
+        try:
+            super().__getattr__("_orig_mod")
             self._orig_mod.training = value
+        except AttributeError:
+            # still initializing
+            pass
 
     def __getattr__(self, name):
         if name == "_orig_mod":
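
For readers skimming the diff: the setter change above swaps an explicit initialization flag (the removed lines) back to probing for `_orig_mod` with try/except (the restored lines). Below is a minimal, self-contained sketch of the two guard strategies using toy wrapper classes; the class names and the `_initialized` flag are illustrative only and are not part of the real `OptimizedModule`.

# Toy illustration of the two guard strategies contrasted above. Neither class
# is the real OptimizedModule; they only show why the `training` setter needs a
# guard while nn.Module.__init__ is still running.
import torch


class ProbeGuardWrapper(torch.nn.Module):
    # Behavior restored by this revert: probe for `_orig_mod` and swallow the
    # AttributeError raised while `super().__init__()` is still in progress.
    def __init__(self, mod):
        super().__init__()  # assigns `training`, which hits the setter below
        self._orig_mod = mod  # registered as a submodule by nn.Module.__setattr__
        self.training = mod.training

    @property
    def training(self):
        return self._orig_mod.training

    @training.setter
    def training(self, value):
        try:
            super().__getattr__("_orig_mod")  # raises until _orig_mod is registered
            self._orig_mod.training = value
        except AttributeError:
            pass  # still inside nn.Module.__init__


class FlagGuardWrapper(torch.nn.Module):
    # The approach this commit reverts: stash the inner module with
    # object.__setattr__ before super().__init__() and track a flag.
    def __init__(self, mod):
        object.__setattr__(self, "_orig_mod", mod)
        object.__setattr__(self, "_initialized", False)
        super().__init__()
        object.__setattr__(self, "_initialized", True)
        self._orig_mod = mod  # re-assign so nn.Module registers it as a submodule
        self.training = mod.training

    @property
    def training(self):
        return self._orig_mod.training

    @training.setter
    def training(self, value):
        if self._initialized:
            self._orig_mod.training = value


inner = torch.nn.Linear(2, 2).eval()
print(ProbeGuardWrapper(inner).training, FlagGuardWrapper(inner).training)  # False False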

0 commit comments