@@ -313,23 +313,12 @@ class OptimizedModule(torch.nn.Module):
         "_forward",
         "__dict__",
         "named_children_walk",
-        "_super_module_initialized",
     }
 
     def __init__(self, mod: torch.nn.Module, dynamo_ctx) -> None:
-        # NOTE: this must go first, because attribute reads/writes of `self`
-        # uses `_orig_mod`, and sometimes users override `Module.__init__` to
-        # do attribute reads/writes on `self`.
-        #
-        # We also can't use regular setattr because `super().__setattr__` will
-        # complain for module value before `super().__init__()`
-        object.__setattr__(self, "_orig_mod", mod)
-        self._super_module_initialized = False
         super().__init__()
-        self._super_module_initialized = True
-
         # Installs the params/buffer
-        self._orig_mod = mod  # `super().__setattr__` will register this module
+        self._orig_mod = mod
         self.dynamo_ctx = dynamo_ctx
         self._initialize()
         self.training = self._orig_mod.training
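With this change, `_orig_mod` is assigned only after `super().__init__()` has run, so the plain assignment goes through `nn.Module.__setattr__`, which registers the wrapped module in `self._modules`. A minimal standalone sketch of the ordering constraint the old code worked around (the `Wrapper` class here is hypothetical, not the PyTorch source):

import torch

class Wrapper(torch.nn.Module):
    def __init__(self, mod: torch.nn.Module) -> None:
        # Assigning a Module-valued attribute *before* super().__init__() would
        # raise AttributeError: nn.Module.__setattr__ needs self._modules,
        # which nn.Module.__init__ creates.
        super().__init__()
        self._orig_mod = mod  # registered as a submodule in self._modules

w = Wrapper(torch.nn.Linear(2, 2))
print(dict(w.named_children()))  # expect {'_orig_mod': Linear(...)}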
@@ -390,11 +379,12 @@ def training(self):
 
     @training.setter
     def training(self, value):
-        # Ignore the `training` mutation in `super().__init__()`, since that's
-        # setting the default on `nn.Module`, but we are mirroring the
-        # `training` attr in `self._orig_mod`.
-        if self._super_module_initialized:
+        try:
+            super().__getattr__("_orig_mod")
             self._orig_mod.training = value
+        except AttributeError:
+            # still initializing
+            pass
 
     def __getattr__(self, name):
         if name == "_orig_mod":
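This hunk swaps the `_super_module_initialized` flag for a probe: `nn.Module.__getattr__` only consults `_parameters`, `_buffers`, and `_modules`, so calling it explicitly raises `AttributeError` until `_orig_mod` has actually been registered. `nn.Module.__init__()` sets the default `training = True` through this property setter before `_orig_mod` exists, and the `except` branch quietly drops that early write instead of consulting a flag. A minimal sketch of the pattern, reusing the hypothetical `Wrapper` from above:

import torch

class Wrapper(torch.nn.Module):
    def __init__(self, mod: torch.nn.Module) -> None:
        super().__init__()  # sets `training`, invoking the setter below early
        self._orig_mod = mod

    @property
    def training(self):
        return self._orig_mod.training

    @training.setter
    def training(self, value):
        try:
            # nn.Module.__getattr__ raises AttributeError until `_orig_mod`
            # is registered in self._modules, i.e. while we are still inside
            # super().__init__().
            super().__getattr__("_orig_mod")
            self._orig_mod.training = value
        except AttributeError:
            pass  # still initializing; nothing to mirror yet

w = Wrapper(torch.nn.Linear(2, 2))
w.eval()
print(w._orig_mod.training)  # False: the mutation was mirrored onto the wrapped module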