8000 Update on "[dynamo] support custom __getattr__ on torch.nn.Modules" · pytorch/pytorch@1c78962 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1c78962

Browse files
committed
Update on "[dynamo] support custom __getattr__ on torch.nn.Modules"
**Summary**: torch.nn.Module implementations previously did not support custom implementations of `__getattr__`; if a torch.nn.Module subclass implemented `__getattr__` and we tried to access an attribute that was expected to be present in `__getattr__`, dynamo would not check `__getattr__` and would error out with an AttributeError. This PR copies the functionality from UserDefinedObjectVariable into torch.nn.Module so that it also supports `__getattr__` Example of a module which previously would fail: ```python class MyMod(torch.nn.Module): def __init__(self): super().__init__() self.custom_dict = {"queue": [torch.rand((2, 2)) for _ in range(3)]} self.other_attr = torch.rand((2, 2)) def __getattr__(self, name): custom_dict = self.custom_dict if name in custom_dict: return custom_dict[name] return super().__getattr__(name) def forward(self, x): return x @ self.other_attr + self.queue[-1] ``` cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 @EikanWang jgong5 @Guobing-Chen @XiaobingSuper zhuhaozhe blzheng @Xia-Weiwen wenzhe-nrv jiayisunx desertfire [ghstack-poisoned]
2 parents 915026f + ba53598 commit 1c78962

File tree

1 file changed

+0
-2
lines changed

1 file changed

+0
-2
lines changed

test/dynamo/test_misc.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,7 +1033,6 @@ def fn(obj, x):
10331033
cnts = torch._dynamo.testing.CompileCounter()
10341034
opt_fn = torch._dynamo.optimize(cnts)(fn)
10351035
self.assertTrue(same(opt_fn(obj, x), fn(obj, x)))
1036-
self.assertTrue(cnts.frame_count <= 2)
10371036

10381037
def test_nn_module_getattr(self):
10391038
class MyMod(torch.nn.Module):
@@ -1081,7 +1080,6 @@ def fn(mod, x):
10811080
cnts = torch._dynamo.testing.CompileCounter()
10821081
opt_fn = torch._dynamo.optimize(cnts)(fn)
10831082
self.assertTrue(same(opt_fn(mod, x), fn(mod, x)))
1084-
self.assertTrue(cnts.frame_count <= 2)
10851083

10861084
def test_user_property(self):
10871085
class MyConfig:

0 commit comments

Comments
 (0)
0