8000 Revert "[dynamo][guards] Make class members go through obj.__class__.… · pytorch/pytorch@805a102 · GitHub
[go: up one dir, main page]

Skip to content

Commit 805a102

Browse files
Revert "[dynamo][guards] Make class members go through obj.__class__.__dict__ (#159534)"
This reverts commit 1616777. Reverted #159534 on behalf of https://github.com/malfet due to Broke some inductor test and lint among other things, see https://hud.pytorch.org/hud/pytorch/pytorch/9c18901bfdc526a1df22866904ddcdb4d4ba5394/1?per_page=50&mergeEphemeralLF=true ([comment](#159534 (comment)))
1 parent 6e8d705 commit 805a102

File tree

9 files changed

+27
-290
lines changed

9 files changed

+27
-290
lines changed

test/dynamo/test_guard_manager.py 10BC0

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -880,9 +880,8 @@ def hook(guard_wrapper, f_locals, builder):
880880
counter += 1
881881

882882
class Bar:
883-
def __init__(self):
884-
self.x = 4
885-
self.y = torch.randn(4)
883+
x = 4
884+
y = torch.randn(4)
886885

887886
bar = Bar()
888887

test/dynamo/test_skip_guard_eval_unsafe.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,8 @@ def fn(x, y):
5454

5555
def test_post_recompile(self):
5656
class Foo:
57-
def __init__(self):
58-
self.a = 4
59-
self.b = 5
57+
a = 4
58+
b = 5
6059

6160
foo = Foo()
6261

torch/_C/_dynamo/guards.pyi

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,6 @@ class DictGuardManager(GuardManager):
139139
class GuardAccessor: ...
140140
class DictGetItemGuardAccessor(GuardAccessor): ...
141141
class GetGenericDictGuardAccessor(GuardAccessor): ...
142-
class TypeDictGuardAccessor(GuardAccessor): ...
143-
class TypeMROGuardAccessor(GuardAccessor): ...
144142

145143
def install_object_aliasing_guard(
146144
guard_managers: list[GuardManager],

torch/_dynamo/guards.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,6 @@
132132
TorchFunctionModeStackSource,
133133
TorchSource,
134134
TupleIteratorGetItemSource,
135-
TypeDictSource,
136-
TypeMROSource,
137135
TypeSource,
138136
UnspecializedBuiltinNNModuleSource,
139137
UnspecializedNNModuleSource,
@@ -864,9 +862,6 @@ def __init__(
864862
self.guard_nn_modules = config.guard_nn_modules and justknobs_check(
865863
"pytorch/compiler:guard_nn_modules"
866864
)
867-
self.already_guarded_not_present_in_generic_dict: OrderedSet[
868-
tuple[str, str]
869-
] = OrderedSet()
870865

871866
def guard_on_dict_keys_and_ignore_order(self, example_value, guard):
872867
dict_mgr = self.get_guard_manager(guard)
@@ -1214,20 +1209,6 @@ def get_guard_manager_from_source(self, source):
12141209
example_value=example_value,
12151210
guard_manager_enum=guard_manager_enum,
12161211
)
1217-
elif istype(source, TypeDictSource):
1218-
assert base_guard_manager # to make mypy happy
1219-
out = base_guard_manager.type_dict_manager(
1220-
source=source_name,
1221-
example_value=example_value,
1222-
guard_manager_enum=guard_manager_enum,
1223-
)
1224-
elif istype(source, TypeMROSource):
1225-
assert base_guard_manager # to make mypy happy
1226-
out = base_guard_manager.type_mro_manager(
1227-
source=source_name,
1228-
example_value=example_value,
1229-
guard_manager_enum=guard_manager_enum,
1230-
)
12311212
elif istype(
12321213
source,
12331214
(
@@ -1653,12 +1634,10 @@ def NOT_PRESENT_IN_GENERIC_DICT(self, guard: Guard, attr=None) -> None:
16531634
assert attr is not None
16541635
ref = self.arg_ref(guard)
16551636
val = self.get(guard.name)
1637+
assert isinstance(val, torch.nn.Module)
16561638

16571639
base_manager = self.get_guard_manager(guard)
16581640

1659-
if (ref, attr) in self.already_guarded_not_present_in_generic_dict:
1660-
return
1661-
16621641
mod_dict_source = f"{guard.name}.__dict__"
16631642
mod_generic_dict_manager = base_manager.get_generic_dict_manager(
16641643
source=mod_dict_source,
@@ -1670,7 +1649,6 @@ def NOT_PRESENT_IN_GENERIC_DICT(self, guard: Guard, attr=None) -> None:
16701649
mod_generic_dict_manager.add_dict_contains_guard(
16711650
False, attr, get_verbose_code_parts(code, guard)
16721651
)
1673-
self.already_guarded_not_present_in_generic_dict.add((ref, attr))
16741652

16751653
def TYPE_MATCH(self, guard: Guard) -> None:
16761654
# ___check_type_id is same as `id(type(x)) == y`

torch/_dynamo/source.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -266,38 +266,6 @@ def name(self) -> str:
266266
return f"object.__getattribute__({self.base.name()}, {self.member!r})"
267267

268268

269-
# Represents obj.__dict__ where obj is a type object
270-
@dataclasses.dataclass(frozen=True)
271-
class TypeDictSource(ChainedSource):
272-
def reconstruct(self, codegen: "PyCodegen"):
273-
codegen(self.base)
274-
codegen.extend_output(codegen.create_load_attrs("__dict__"))
275-
276-
def guard_source(self):
277-
return self.base.guard_source()
278-
279-
def name(self):
280-
# type(ob).__dict__ can return a proxy of the dict. But in the C++
281-
# guard accessor, we are use type->tp_dict which is a dict. So,
282-
# forcefully pass a dict object to ensure that the GuardManager
283-
# registers that its working on a dict object.
284-
return f"dict({self.base.name()}.__dict__)"
285-
286-
287-
# Represents obj.__mro__ where object is type object
288-
@dataclasses.dataclass(frozen=True)
289-
class TypeMROSource(ChainedSource):
290-
def reconstruct(self, codegen: "PyCodegen"):
291-
codegen(self.base)
292-
codegen.extend_output(codegen.create_load_attrs("__mro__"))
293-
294-
def guard_source(self):
295-
return self.base.guard_source()
296-
297-
def name(self):
298-
return f"{self.base.name()}.__mro__"
299-
300-
301269
@dataclasses.dataclass(frozen=True)
302270
class LocalCellSource(Source):
303271
"""

torch/_dynamo/variables/misc.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
AttrSource,
4343
GenericAttrSource,
4444
GetItemSource,
45-
TypeMROSource,
4645
TypeSource,
4746
WeakRefCallSource,
4847
)
@@ -135,7 +134,9 @@ def _resolved_getattr_and_source(self, tx: "InstructionTranslator", name):
135134
# Equivalent of something like type(L['self']).__mro__[1].attr_name
136135
if type_to_use_source:
137136
source = AttrSource(
138-
GetItemSource(TypeMROSource(type_to_use_source), index),
137+
GetItemSource(
138+
AttrSource(type_to_use_source, "__mro__"), index
139+
),
139140
name,
140141
)
141142
return resolved_getattr, source
@@ -246,7 +247,7 @@ def call_method(
246247
# different from type(self) with polymorphism.
247248
cls_source = None
248249
if self.objvar.source:
249-
cls_source = TypeSource(self.objvar.source)
250+
cls_source = AttrSource(self.objvar.source, "__class__")
250251
cls_variable = VariableTracker.build(
251252
tx, self.objvar.value_type, cls_source
252253
)

torch/_dynamo/variables/nn_module.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -989,7 +989,7 @@ def call_function(
989989
fn = self.value_type.forward
990990

991991
if self.source:
992-
source = self.get_source_by_walking_mro(name)
992+
source = AttrSource(AttrSource(self.source, "__class__"), name)
993993
else:
994994
source = None
995995

@@ -1017,7 +1017,7 @@ def call_method(
10171017
if name in ["_call_impl", "_wrapped_call_impl"]:
10181018
fn = getattr(self.value_type, name)
10191019
if self.source:
1020-
source = self.get_source_by_walking_mro(name)
1020+
source = AttrSource(AttrSource(self.source, "__class__"), name)
10211021
else:
10221022
source = None
10231023

@@ -1032,7 +1032,9 @@ def call_method(
10321032
method = None
10331033

10341034
if isinstance(method, staticmethod):
1035-
source = AttrSource(self.get_source_by_walking_mro(name), "__func__")
1035+
source = AttrSource(
1036+
AttrSource(AttrSource(self.source, "__class__"), name), "__func__"
1037+
)
10361038
return tx.inline_user_function_return(
10371039
variables.UserFunctionVariable(method.__func__, source=source),
10381040
args,

torch/_dynamo/variables/user_defined.py

Lines changed: 13 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,8 @@
6060
AttrSource,
6161
CallFunctionNoArgsSource,
6262
DataclassFieldsSource,
63-
DictGetItemSource,
6463
GetItemSource,
6564
RandomValueSource,
66-
TypeDictSource,
67-
TypeMROSource,
6865
TypeSource,
6966
UnspecializedParamBufferSource,
7067
)
@@ -1001,9 +998,11 @@ def call_method(
1001998

1002999
# check for methods implemented in C++
10031000
if isinstance(method, types.FunctionType):
1004-
source = None
1005-
if self.source:
1006-
source = self.get_source_by_walking_mro(name)
1001+
source = (
1002+
None
1003+
if self.source is None
1004+
else AttrSource(AttrSource(self.source, "__class__"), name)
1005+
)
10071006
# TODO(jansel): add a guard to check for monkey patching?
10081007
from ..mutation_guard import unpatched_nn_module_init
10091008

@@ -1225,40 +1224,12 @@ def get_source_by_walking_mro(self, name):
12251224

12261225
for idx, klass in enumerate(type(self.value).__mro__):
12271226
if name in klass.__dict__:
1228-
if idx != 0:
1229-
mro_source = TypeMROSource(self.cls_source)
1230-
klass_source = GetItemSource(mro_source, idx)
1231-
else:
1232-
klass_source = self.cls_source
1233-
dict_source = TypeDictSource(klass_source)
1234-
out_source = DictGetItemSource(dict_source, name)
1235-
1236-
for absent_idx in range(1, idx):
1237-
# Insert a guard that the name is not present in the mro hierarchy
1238-
mro_source = TypeMROSource(self.cls_source)
1239-
klass_source = GetItemSource(mro_source, absent_idx)
1240-
dict_source = TypeDictSource(klass_source)
1241-
install_guard(
1242-
dict_source.make_guard(
1243-
functools.partial(
1244-
GuardBuilder.DICT_CONTAINS, key=name, invert=True
1245-
)
1246-
)
1247-
)
1248-
# Insert a guard that the name is not present in the object __dict__
1249-
if (
1250-
self.source
1251-
and hasattr(self.value, "__dict__")
1252-
and name not in self.value.__dict__
1253-
):
1254-
install_guard(
1255-
self.source.make_guard(
1256-
functools.partial(
1257-
GuardBuilder.NOT_PRESENT_IN_GENERIC_DICT, attr=name
1258-
)
1259-
)
1260-
)
1261-
return out_source
1227+
mro_source = AttrSource(self.cls_source, "__mro__")
1228+
klass_source = GetItemSource(mro_source, idx)
1229+
dict_source = AttrSource(klass_source, "__dict__")
1230+
# TODO(anijain2305) - This is a mapping proxy object. Ideally we
1231+
# should use DictGetItemSource here.
1232+
return GetItemSource(dict_source, name)
12621233

12631234
unimplemented_v2(
12641235
gb_type="could not find name in object's mro",
@@ -1368,17 +1339,10 @@ def var_getattr(self, tx: "InstructionTranslator", name):
13681339
if subobj is torch.nn.Module.__init__:
13691340
subobj = unpatched_nn_module_init
13701341

1371-
subobj_from_class = inspect.getattr_static(
1372-
self.value.__class__, name, NO_SUCH_SUBOBJ
1373-
)
1374-
is_accessible_from_type_mro = (
1375-
subobj_from_class is subobj and self.cls_source is not None
1376-
)
1377-
13781342
if isinstance(subobj, property):
13791343
if self.source:
13801344
# Read the class attribute to reach the property
1381-
source = self.get_source_by_walking_mro(name)
1345+
source = AttrSource(AttrSource(self.source, "__class__"), name)
13821346
# Get the getter function
13831347
source = AttrSource(source, "fget")
13841348
return variables.UserMethodVariable(
@@ -1396,11 +1360,6 @@ def var_getattr(self, tx: "InstructionTranslator", name):
13961360
# Safe because `staticmethod.__get__` basically won't trigger user
13971361
# code and just returns the underlying `__func__`:
13981362
# https://github.com/python/cpython/blob/3.11/Objects/funcobject.c#L1088-L1100
1399-
if is_accessible_from_type_mro:
1400-
# Accessing from __dict__ does not resolve the descriptor, it
1401-
# returns a staticmethod object, so access the __func__
1402-
# attribute to get to the actual function.
1403-
source = AttrSource(self.get_source_by_walking_mro(name), "__func__")
14041363
func = subobj.__get__(self.value)
14051364
return VariableTracker.build(tx, func, source)
14061365
elif isinstance(subobj, classmethod):
@@ -1526,15 +1485,10 @@ def var_getattr(self, tx: "InstructionTranslator", name):
15261485
source = self._wrap_source(source)
15271486

15281487
if subobj is not NO_SUCH_SUBOBJ:
1529-
if is_wrapper_or_member_descriptor(
1530-
subobj
1531-
) or torch._C._dynamo.utils.is_instancemethod(subobj):
1488+
if is_wrapper_or_member_descriptor(subobj):
15321489
options = {"source": source}
15331490
return variables.GetAttrVariable(self, name, **options)
15341491
if source:
1535-
if is_accessible_from_type_mro:
1536-
source = self.get_source_by_walking_mro(name)
1537-
15381492
return variables.LazyVariableTracker.create(subobj, source)
15391493
else:
15401494
# Check if the subobj is accessible from the class itself. If the class source is known, we can create a

0 commit comments

Comments
 (0)
0