8000 [dynamo] support custom __getattr__ on torch.nn.Modules · pytorch/pytorch@0485ec4 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0485ec4

Browse files
committed
[dynamo] support custom __getattr__ on torch.nn.Modules
**Summary**: torch.nn.Module implementations previously did not support custom implementations of `__getattr__`; if a torch.nn.Module subclass implemented `__getattr__` and we tried to access an attribute that was expected to be present in `__getattr__`, dynamo would not check `__getattr__` and would error out with an AttributeError. This PR copies the functionality from UserDefinedObjectVariable into torch.nn.Module so that it also supports `__getattr__` Example of a module which previously would fail: ```python class MyMod(torch.nn.Module): def __init__(self): super().__init__() self.custom_dict = {"queue": [torch.rand((2, 2)) for _ in range(3)]} self.other_attr = torch.rand((2, 2)) def __getattr__(self, name): custom_dict = self.custom_dict if name in custom_dict: return custom_dict[name] return super().__getattr__(name) def forward(self, x): return x @ self.other_attr + self.queue[-1] ``` ghstack-source-id: da0a0d8 Pull Request resolved: #94658
1 parent 948cd61 commit 0485ec4

File tree

4 files changed

+89
-19
lines changed

4 files changed

+89
-19
lines changed

test/dynamo/test_misc.py

Lines changed: 24 additions & 0 deletions
8000
Original file line numberDiff line numberDiff line change
@@ -1010,6 +1010,30 @@ def fn(cfg, x):
10101010
self.assertEqual(cnts.frame_count, 1)
10111011
self.assertEqual(cnts.op_count, 3)
10121012

1013+
def test_nn_module_getattr(self):
1014+
class MyMod(torch.nn.Module):
1015+
def __init__(self):
1016+
super().__init__()
1017+
self.custom_dict = {"queue": [torch.rand((2, 2)) for _ in range(3)]}
1018+
self.other_attr = torch.rand((2, 2))
1019+
1020+
def __getattr__(self, name):
1021+
custom_dict = self.custom_dict
1022+
if name in custom_dict:
1023+
return custom_dict[name]
1024+
return super().__getattr__(name)
1025+
1026+
def forward(self, x):
1027+
return x @ self.other_attr + self.queue[-1]
1028+
1029+
x = torch.rand((2, 2))
1030+
mod = MyMod()
1031+
cnts = torch._dynamo.testing.CompileCounter()
1032+
opt_mod = torch._dynamo.optimize(cnts)(mod)
1033+
self.assertTrue(same(opt_mod(x), mod(x)))
1034+
self.assertTrue(cnts.frame_count, 1)
1035+
self.assertTrue(cnts.op_count, 2)
1036+
10131037
def test_user_property(self):
10141038
class MyConfig:
10151039
@property

torch/_dynamo/utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1301,3 +1301,26 @@ def import_submodule(mod: types.ModuleType):
13011301
for filename in sorted(os.listdir(os.path.dirname(mod.__file__))):
13021302
if filename.endswith(".py") and filename[0] != "_":
13031303
importlib.import_module(f"{mod.__name__}.{filename[:-3]}")
1304+
1305+
1306+
def object_has_getattribute(value: Any):
1307+
try:
1308+
if isinstance(
1309+
inspect.getattr_static(type(value), "__getattribute__"),
1310+
types.FunctionType,
1311+
):
1312+
return True
1313+
except AttributeError:
1314+
pass
1315+
return False
1316+
1317+
1318+
def get_custom_getattr(value: Any):
1319+
try:
1320+
getattr_fn = inspect.getattr_static(type(value), "__getattr__")
1321+
except AttributeError:
1322+
getattr_fn = None
1323+
if getattr_fn is torch.nn.Module.__getattr__:
1324+
# ignore this case of getattr
1325+
getattr_fn = None
1326+
return getattr_fn

torch/_dynamo/variables/nn_module.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
from ..mutation_guard import GenerationTracker
1515
from ..source import AttrSource, GetItemSource, NNModuleSource, NotNNModuleSource
1616
from ..utils import (
17+
get_custom_getattr,
1718
is_lazy_module,
1819
is_safe_constant,
1920
istensor,
2021
istype,
22+
object_has_getattribute,
2123
proxy_args_kwargs,
2224
)
2325
from .base import MutableLocal, typestr, VariableTracker
@@ -86,6 +88,25 @@ def convert_to_unspecialized(self, tx):
8688
GenerationTracker.mark_class_dynamic(type(mod))
8789
raise RestartAnalysis()
8890

91+
def _custom_getattr_fallback(self, base, tx, name, options):
92+
"""Check for a __getattr__ and handle it specially if it is implemented"""
93+
if object_has_getattribute(base):
94+
unimplemented("torch.nn.Module with a custom __getattribute__ defined")
95+
96+
getattr_fn = get_custom_getattr(base)
97+
if getattr_fn is None:
98+
return None
99+
100+
if not isinstance(getattr_fn, types.FunctionType):
101+
unimplemented("torch.nn.Module with a non-function custom __getattr__")
102+
103+
return variables.UserMethodVariable(
104+
getattr_fn,
105+
self,
106+
**options
107+
# getattr_fn, self, **options
108+
).call_function(tx, [variables.ConstantVariable(name)], {})
109+
89110
def var_getattr(self, tx, name):
90111
from .builder import VariableBuilder
91112

@@ -121,8 +142,18 @@ def var_getattr(self, tx, name):
23D3 121142
elif "_buffers" in base_dict and name in base_dict["_buffers"]:
122143
subobj = base_dict["_buffers"][name]
123144
else:
124-
subobj = inspect.getattr_static(base, name)
125-
object_member = False
145+
try:
146+
subobj = inspect.getattr_static(base, name)
147+
object_member = False
148+
except AttributeError as e:
149+
# see if we can fallback to __getattr__, which is not checked by getattr_static
150+
result = self._custom_getattr_fallback(
151+
base=base, tx=tx, name=name, options=options
152+
)
153+
if result is not None:
154+
return result
155+
# if we can't find a __getattr__, just raise the AttributeError
156+
raise e
126157

127158
if name == "__class__" and not object_member:
128159
return variables.UserDefinedClassVariable(base.__class__, **options)

torch/_dynamo/variables/user_defined.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
from ..exc import unimplemented
1414
from ..guards import GuardBuilder
1515
from ..source import AttrSource, ODictGetItemSource, RandomValueSource
16-
from ..utils import is_namedtuple_cls, namedtuple_fields
16+
from ..utils import (
17+
get_custom_getattr,
18+
is_namedtuple_cls,
19+
namedtuple_fields,
20+
object_has_getattribute,
21+
)
1722
from .base import MutableLocal, VariableTracker
1823
from .misc import NullContextVariable
1924

@@ -264,24 +269,11 @@ def call_function(
264269
return super().call_function(tx, args, kwargs)
265270

266271
def _check_for_getattribute(self):
267-
try:
268-
if isinstance(
269-
inspect.getattr_static(type(self.value), "__getattribute__"),
270-
types.FunctionType,
271-
):
272-
unimplemented("UserDefinedObjectVariable with custom __getattribute__")
273-
except AttributeError:
274-
pass
272+
if object_has_getattribute(self.value):
273+
unimplemented("UserDefinedObjectVariable with custom __getattribute__")
275274

276275
def _check_for_getattr(self):
277-
try:
278-
getattr_fn = inspect.getattr_static(type(self.value), "__getattr__")
279-
except AttributeError:
280-
getattr_fn = None
281-
if getattr_fn is torch.nn.Module.__getattr__:
282-
# ignore this case of getattr
283-
getattr_fn = None
284-
return getattr_fn
276+
return get_custom_getattr(self.value)
285277

286278
def _getattr_static(self, name):
287279
if (

0 commit comments

Comments
 (0)
0