8000 Revert "[dynamo] support custom __getattr__ on torch.nn.Modules (#946… · ROCm/pytorch@ba5ef67 · GitHub
[go: up one dir, main page]

Skip to content

Commit ba5ef67

Browse files
committed
Revert "[dynamo] support custom __getattr__ on torch.nn.Modules (pytorch#94658)"
This reverts commit a4085ab.
1 parent ee6c915 commit ba5ef67

File tree

4 files changed

+19
-133
lines changed

4 files changed

+19
-133
lines changed

test/dynamo/test_misc.py

Lines changed: 0 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,77 +1010,6 @@ def fn(cfg, x):
10101010
self.assertEqual(cnts.frame_count, 1)
10111011
self.assertEqual(cnts.op_count, 3)
10121012

1013-
def test_user_getattribute(self):
1014-
class MyObject:
1015-
def __init__(self):
1016-
self.custom_dict = {"a": torch.rand((2, 2))}
1017-
self.my_number = 42
1018-
1019-
def __getattribute__(self, name):
1020-
custom_dict = super().__getattribute__("custom_dict")
1021-
if name in custom_dict:
1022-
return custom_dict[name]
1023-
return super().__getattribute__(name)
1024-
1025-
def run(self, x):
1026-
return self.my_number * x + self.a * x
1027-
1028-
def fn(obj, x):
1029-
return obj.run(x)
1030-
1031-
obj = MyObject()
1032-
x = torch.rand((2, 2))
1033-
cnts = torch._dynamo.testing.CompileCounter()
1034-
opt_fn = torch._dynamo.optimize(cnts)(fn)
1035-
self.assertTrue(same(opt_fn(obj, x), fn(obj, x)))
1036-
1037-
def test_nn_module_getattr(self):
1038-
class MyMod(torch.nn.Module):
1039-
def __init__(self):
1040-
super().__init__()
1041-
self.custom_dict = {"queue": [torch.rand((2, 2)) for _ in range(3)]}
1042-
self.other_attr = torch.rand((2, 2))
1043-
1044-
def __getattr__(self, name):
1045-
custom_dict = self.custom_dict
1046-
if name in custom_dict:
1047-
return custom_dict[name]
1048-
return super().__getattr__(name)
1049-
1050-
def forward(self, x):
1051-
return x @ self.other_attr + self.queue[-1]
1052-
1053-
x = torch.rand((2, 2))
1054-
mod = MyMod()
1055-
cnts = torch._dynamo.testing.CompileCounter()
1056-
opt_mod = torch._dynamo.optimize(cnts)(mod)
1057-
self.assertTrue(same(opt_mod(x), mod(x)))
1058-
self.assertTrue(cnts.frame_count, 1)
1059-
self.assertTrue(cnts.op_count, 2)
1060-
1061-
def test_nn_module_getattribute(self):
1062-
class MyMod(torch.nn.Module):
1063-
def __init__(self):
1064-
super().__init__()
1065-
self.my_number = 42
1066-
1067-
def __getattribute__(self, name):
1068-
if name == "special_attr":
1069-
return torch.tensor([[1, 2], [3, 4]])
1070-
return super().__getattribute__(name)
1071-
1072-
def forward(self, x):
1073-
return self.my_number * x + self.special_attr * x
1074-
1075-
def fn(mod, x):
1076-
return mod(x)
1077-
1078-
mod = MyMod()
1079-
x = torch.rand((2, 2))
1080-
cnts = torch._dynamo.testing.CompileCounter()
1081-
opt_fn = torch._dynamo.optimize(cnts)(fn)
1082-
self.assertTrue(same(opt_fn(mod, x), fn(mod, x)))
1083-
10841013
def test_user_property(self):
10851014
class MyConfig:
10861015
@property

torch/_dynamo/utils.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,26 +1302,3 @@ def import_submodule(mod: types.ModuleType):
13021302
for filename in sorted(os.listdir(os.path.dirname(mod.__file__))):
13031303
if filename.endswith(".py") and filename[0] != "_":
13041304
importlib.import_module(f"{mod.__name__}.{filename[:-3]}")
1305-
1306-
1307-
def object_has_getattribute(value: Any):
1308-
try:
1309-
if isinstance(
1310-
inspect.getattr_static(type(value), "__getattribute__"),
1311-
types.FunctionType,
1312-
):
1313-
return True
1314-
except AttributeError:
1315-
pass
1316-
return False
1317-
1318-
1319-
def get_custom_getattr(value: Any):
1320-
try:
1321-
getattr_fn = inspect.getattr_static(type(value), "__getattr__")
1322-
except AttributeError:
1323-
getattr_fn = None
1324-
if getattr_fn is torch.nn.Module.__getattr__:
1325-
# ignore this case of getattr
1326-
getattr_fn = None
1327-
return getattr_fn

torch/_dynamo/variables/nn_module.py

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,10 @@
1414
from ..mutation_guard import GenerationTracker
1515
from ..source import AttrSource, GetItemSource, NNModuleSource, NotNNModuleSource
1616
from ..utils import (
17-
get_custom_getattr,
1817
is_lazy_module,
1918
is_safe_constant,
2019
istensor,
2120
istype,
22-
object_has_getattribute,
2321
proxy_args_kwargs,
2422
)
2523
from .base import MutableLocal, typestr, VariableTracker
@@ -88,22 +86,6 @@ def convert_to_unspecialized(self, tx):
8886
GenerationTracker.mark_class_dynamic(type(mod))
8987
raise RestartAnalysis()
9088

91-
def _custom_getattr_fallback(self, base, tx, name, options):
92-
"""Check for a __getattr__ and handle it specially if it is implemented"""
93-
if object_has_getattribute(base):
94-
unimplemented("torch.nn.Module with a custom __getattribute__ defined")
95-
96-
getattr_fn = get_custom_getattr(base)
97-
if getattr_fn is None:
98-
return None
99-
100-
if not isinstance(getattr_fn, types.FunctionType):
101-
unimplemented("torch.nn.Module with a non-function custom __getattr__")
102-
103-
return variables.UserMethodVariable(getattr_fn, self, **options).call_function(
104-
tx, [variables.ConstantVariable(name)], {}
105-
)
106-
10789
def var_getattr(self, tx, name):
10890
from .builder import VariableBuilder
10991

@@ -139,18 +121,8 @@ def var_getattr(self, tx, name):
139121
elif "_buffers" in base_dict and name in base_dict["_buffers"]:
140122
subobj = base_dict["_buffers"][name]
141123
else:
142-
try:
143-
subobj = inspect.getattr_static(base, name)
144-
object_member = False
145-
except AttributeError:
146-
# see if we can fallback to __getattr__, which is not checked by getattr_static
147-
result = self._custom_getattr_fallback(
148-
base=base, tx=tx, name=name, options=options
149-
)
150-
if result is not None:
151-
return result
152-
# if we can't find a __getattr__, just raise the AttributeError
153-
raise
124+
subobj = inspect.getattr_static(base, name)
125+
object_member = False
154126

155127
if name == "__class__" and not object_member:
156128
return variables.UserDefinedClassVariable(base.__class__, **options)

torch/_dynamo/variables/user_defined.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,7 @@
1313
from ..exc import unimplemented
1414
from ..guards import GuardBuilder
1515
from ..source import AttrSource, ODictGetItemSource, RandomValueSource
16-
from ..utils import (
17-
get_custom_getattr,
18-
is_namedtuple_cls,
19-
namedtuple_fields,
20-
object_has_getattribute,
21-
)
16+
from ..utils import is_namedtuple_cls, namedtuple_fields
2217
from .base import MutableLocal, VariableTracker
2318
from .misc import NullContextVariable
2419

@@ -269,11 +264,24 @@ def call_function(
269264
return super().call_function(tx, args, kwargs)
270265

271266
def _check_for_getattribute(self):
272-
if object_has_getattribute(self.value):
273-
unimplemented("UserDefinedObjectVariable with custom __getattribute__")
267+
try:
268+
if isinstance(
269+
inspect.getattr_static(type(self.value), "__getattribute__"),
270+
types.FunctionType,
271+
):
272+
unimplemented("UserDefinedObjectVariable with custom __getattribute__")
273+
except AttributeError:
274+
pass
274275

275276
def _check_for_getattr(self):
276-
return get_custom_getattr(self.value)
277+
try:
278+
getattr_fn = inspect.getattr_static(type(self.value), "__getattr__")
279+
except AttributeError:
280+
getattr_fn = None
281+
if getattr_fn is torch.nn.Module.__getattr__:
282+
# ignore this case of getattr
283+
getattr_fn = None
284+
return getattr_fn
277285

278286
def _getattr_static(self, name):
279287
if (

0 commit comments

Comments
 (0)
0