8000 [dynamo] Support `delattr` on result of `torch.compile(module)` (#152… · pytorch/pytorch@6c025b5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6c025b5

Browse files
StrongerXipytorchmergebot
authored andcommitted
[dynamo] Support delattr on result of torch.compile(module) (#152741)
This is essentially a follow-up on #122098, where we added support of `getattr` and `setattr` on result of `torch.compile(module)`, but didn't add support for `delattr`. Fixes #150711. Pull Request resolved: #152741 Approved by: https://github.com/anijain2305 ghstack dependencies: #152740
1 parent 0886d40 commit 6c025b5

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

test/dynamo/test_modules.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3047,6 +3047,19 @@ def forward(self, inp):
30473047

30483048
self.assertEqual(model.x, compiled_model.x)
30493049

3050+
def test_delattr_on_compiled_module(self):
3051+
class Mod(torch.nn.Module):
3052+
def forward(self, x):
3053+
return x + 1
3054+
3055+
model = Mod()
3056+
compiled_model = torch.compile(model)
3057+
compiled_model.foo = 42
3058+
del compiled_model.foo
3059+
3060+
self.assertFalse(hasattr(model, "foo"))
3061+
self.assertFalse(hasattr(compiled_model, "foo"))
3062+
30503063
def test_globals_change_in_other_file(self):
30513064
@torch.compile(backend="eager", fullgraph=True)
30523065
def fn(x):

torch/_dynamo/eval_frame.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,15 @@ def __setattr__(self, name, val) -> None:
389389
return super().__setattr__(name, val)
390390
return setattr(self._orig_mod, name, val)
391391

392+
def __delattr__(self, name):
393+
# This mirrors `__setattr__`
394+
if hasattr(type(self), name):
395+
return super().__delattr__(name)
396+
397+
if name in OptimizedModule._opt_mod_attributes:
398+
return super().__delattr__(name)
399+
return delattr(self._orig_mod, name)
400+
392401
def _call_lazy_check(self, *args, **kwargs):
393402
if (
394403
hasattr(self._orig_mod, "_initialize_hook")

0 commit comments

Comments
 (0)
0