8000 [dynamo] support custom __getattr__ on torch.nn.Modules (#94658) · kulinseth/pytorch@f7b9686 · GitHub
[go: up one dir, main page]

Skip to content

Commit f7b9686

Browse files
davidberard98jhavukainen
authored andcommitted
[dynamo] support custom __getattr__ on torch.nn.Modules (pytorch#94658)
**Summary**: torch.nn.Module implementations previously did not support custom implementations of `__getattr__`; if a torch.nn.Module subclass implemented `__getattr__` and we tried to access an attribute that was expected to be present in `__getattr__`, dynamo would not check `__getattr__` and would error out with an AttributeError. This PR copies the functionality from UserDefinedObjectVariable into torch.nn.Module so that it also supports `__getattr__` Example of a module which previously would fail: ```python class MyMod(torch.nn.Module): def __init__(self): super().__init__() self.custom_dict = {"queue": [torch.rand((2, 2)) for _ in range(3)]} self.other_attr = torch.rand((2, 2)) def __getattr__(self, name): custom_dict = self.custom_dict if name in custom_dict: return custom_dict[name] return super().__getattr__(name) def forward(self, x): return x @ self.other_attr + self.queue[-1] ``` Pull Request resolved: pytorch#94658 Approved by: https://github.com/yanboliang, https://github.com/jansel
1 parent 4b70206 commit f7b9686

File tree

4 files changed

+133
-19
lines changed

4 files changed

+133
-19
lines changed

test/dynamo/test_misc.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,6 +1010,77 @@ def fn(cfg, x):
10101010
self.assertEqual(cnts.frame_count, 1)
10111011
self.assertEqual(cnts.op_count, 3)
10121012

1013+
def test_user_getattribute(self):
1014+
class MyObject:
1015+
def __init__(self):
1016+
self.custom_dict = {"a": torch.rand((2, 2))}
1017+
self.my_number = 42
1018+
1019+
def __getattribute__(self, name):
1020+
custom_dict = super().__getattribute__("custom_dict")
1021+
if name in custom_dict:
1022+
return custom_dict[name]
1023+
return super().__getattribute__(name)
1024+
1025+
def run(self, x):
1026+
return self.my_number * x + self.a * x
1027+
1028+
def fn(obj, x):
1029+
return obj.run(x)
1030+
1031+
obj = MyObject()
1032+
x = torch.rand((2, 2))
1033+
cnts = torch._dynamo.testing.CompileCounter()
1034+
opt_fn = torch._dynamo.optimize(cnts)(fn)
1035+
self.assertTrue(same(opt_fn(obj, x), fn(obj, x)))
1036+
1037+
def test_nn_module_getattr(self):
1038+
class MyMod(torch.nn.Module):
1039+
def __init__(self):
1040+
super().__init__()
1041+
self.custom_dict = {"queue": [torch.rand((2, 2)) for _ in range(3)]}
1042+
self.other_attr = torch.rand((2, 2))
1043+
1044+
def __getattr__(self, name):
1045+
custom_dict = self.custom_dict
1046+
if name in custom_dict:
1047+
return custom_dict[name]
1048+
return super().__getattr__(name)
1049+
1050+
def forward(self, x):
1051+
return x @ self.other_attr + self.queue[-1]
1052+
1053+
x = torch.rand((2, 2))
1054+
mod = MyMod()
1055+
cnts = torch._dynamo.testing.CompileCounter()
1056+
opt_mod = torch._dynamo.optimize(cnts)(mod)
1057+
self.assertTrue(same(opt_mod(x), mod(x)))
1058+
self.assertTrue(cnts.frame_count, 1)
1059+
self.assertTrue(cnts.op_count, 2)
1060+
1061+
def test_nn_module_getattribute(self):
1062+
class MyMod(torch.nn.Module):
1063+
def __init__(self):
1064+
super().__init__()
1065+
self.my_number = 42
1066+
1067+
def __getattribute__(self, name):
1068+
if name == "special_attr":
1069+
return torch.tensor([[1, 2], [3, 4]])
1070+
return super().__getattribute__(name)
1071+
1072+
def forward(self, x):
1073+
return self.my_number * x + self.special_attr * x
1074+
1075+
def fn(mod, x):
1076+
return mod(x)
1077+
1078+
mod = MyMod()
1079+
x = torch.rand((2, 2))
1080+
cnts = torch._dynamo.testing.CompileCounter()
1081+
opt_fn = torch._dynamo.optimize(cnts)(fn)
1082+
self.assertTrue(same(opt_fn(mod, x), fn(mod, x)))
1083+
10131084
def test_user_property(self):
10141085
class MyConfig:
10151086
@property

torch/_dynamo/utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,3 +1302,26 @@ def import_submodule(mod: types.ModuleType):
13021302
for filename in sorted(os.listdir(os.path.dirname(mod.__file__))):
13031303
if filename.endswith(".py") and filename[0] != "_":
13041304
importlib.import_module(f"{mod.__name__}.{filename[:-3]}")
1305+
1306+
1307+
def object_has_getattribute(value: Any):
1308+
try:
1309+
if isinstance(
1310+
inspect.getattr_static(type(value), "__getattribute__"),
1311+
types.FunctionType,
1312+
):
1313+
return True
1314+
except AttributeError:
1315+
pass
1316+
return False
1317+
1318+
1319+
def get_custom_getattr(value: Any):
1320+
try:
1321+
getattr_fn = inspect.getattr_static(type(value), "__getattr__")
1322+
except AttributeError:
1323+
getattr_fn = None
1324+
if getattr_fn is torch.nn.Module.__getattr__:
1325+
# ignore this case of getattr
1326+
getattr_fn = None
1327+
return getattr_fn

torch/_dynamo/variables/nn_module.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
from ..mutation_guard import GenerationTracker
1515
from ..source import AttrSource, GetItemSource, NNModuleSource, NotNNModuleSource
1616
from ..utils import (
17+
get_custom_getattr,
1718
is_lazy_module,
1819
is_safe_constant,
1920
istensor,
2021
istype,
22+
object_has_getattribute,
2123
proxy_args_kwargs,
2224
)
2325
from .base import MutableLocal, typestr, VariableTracker
@@ -86,6 +88,22 @@ def convert_to_unspecialized(self, tx):
8688
GenerationTracker.mark_class_dynamic(type(mod))
8789
raise RestartAnalysis()
8890

91+
def _custom_getattr_fallback(self, base, tx, name, options):
92+
"""Check for a __getattr__ and handle it specially if it is implemented"""
93+
if object_has_getattribute(base):
94+
unimplemented("torch.nn.Module with a custom __getattribute__ defined")
95+
96+
getattr_fn = get_custom_getattr(base)
97+
if getattr_fn is None:
98+
return None
99+
100+
if not isinstance(getattr_fn, types.FunctionType):
101+
unimplemented("torch.nn.Module with a non-function custom __getattr__")
102+
103+
return variables.UserMethodVariable(getattr_fn, self, **options).call_function(
104+
tx, [variables.ConstantVariable(name)], {}
105+
)
106+
89107
def var_getattr(self, tx, name):
90108
from .builder import VariableBuilder
91109

@@ -121,8 +139,18 @@ def var_getattr(self, tx, name):
121139
elif "_buffers" in base_dict and name in base_dict["_buffers"]:
122140
subobj = base_dict["_buffers"][name]
123141
else:
124-
subobj = inspect.getattr_static(base, name)
125-
object_member = False
142+
try:
143+
subobj = inspect.getattr_static(base, name)
144+
object_member = False
145+
except AttributeError:
146+
# see if we can fallback to __getattr__, which is not checked by getattr_static
147+
result = self._custom_getattr_fallback(
148+
base=base, tx=tx, name=name, options=options
149+
)
150+
if result is not None:
151+
return result
152+
# if we can't find a __getattr__, just raise the AttributeError
153+
raise
126154

127155
if name == "__class__" and not object_member:
128156
return variables.UserDefinedClassVariable(base.__class__, **options)

torch/_dynamo/variables/user_defined.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
from ..exc import unimplemented
1414
from ..guards import GuardBuilder
1515
from ..source import AttrSource, ODictGetItemSource, RandomValueSource
16-
from ..utils import is_namedtuple_cls, namedtuple_fields
16+
from ..utils import (
17+
get_custom_getattr,
18+
is_namedtuple_cls,
19+
namedtuple_fields,
20+
object_has_getattribute,
21+
)
1722
from .base import MutableLocal, VariableTracker
1823
from .misc import NullContextVariable
1924

@@ -264,24 +269,11 @@ def call_function(
264269
return super().call_function(tx, args, kwargs)
265270

266271
def _check_for_getattribute(self):
267-
try:
268-
if isinstance(
269-
inspect.getattr_static(type(self.value), "__getattribute__"),
270-
types.FunctionType,
271-
):
272-
unimplemented("UserDefinedObjectVariable with custom __getattribute__")
273-
except AttributeError:
274-
pass
272+
if object_has_getattribute(self.value):
273+
unimplemented("UserDefinedObjectVariable with custom __getattribute__")
275274

276275
def _check_for_getattr(self):
277-
try:
278-
getattr_fn = inspect.getattr_static(type(self.value), "__getattr__")
279-
except AttributeError:
280-
getattr_fn = None
281-
if getattr_fn is torch.nn.Module.__getattr__:
282-
# ignore this case of getattr
283-
getattr_fn = None
284-
return getattr_fn
276+
return get_custom_getattr(self.value)
285277

286278
def _getattr_static(self, name):
287279
if (

0 commit comments

Comments
 (0)
0