8000 [dynamo] Support `delattr` on result of `torch.compile(module)` · pytorch/pytorch@aafe8a6 · GitHub
  • [go: up one dir, main page]

    Skip to content

    Commit aafe8a6

    Browse files
    committed
    [dynamo] Support delattr on result of torch.compile(module)
    This is essentially a follow-up on #122098, where we added support of `getattr` and `setattr` on result of `torch.compile(module)`, but didn't add support for `delattr`. Fixes #150711. ghstack-source-id: 1bc044f Pull Request resolved: #152741
    1 parent a1e368b commit aafe8a6

    File tree

    2 files changed

    +22
    -0
    lines changed

    2 files changed

    +22
    -0
    lines changed

    test/dynamo/test_modules.py

    Lines changed: 13 additions & 0 deletions
    Original file line numberDiff line numberDiff line change
    @@ -3047,6 +3047,19 @@ def forward(self, inp):
    30473047

    30483048
    self.assertEqual(model.x, compiled_model.x)
    30493049

    3050+
    def test_delattr_on_compiled_module(self):
    3051+
    class Mod(torch.nn.Module):
    3052+
    def forward(self, x):
    3053+
    return x + 1
    3054+
    3055+
    model = Mod()
    3056+
    compiled_model = torch.compile(model)
    3057+
    compiled_model.foo = 42
    3058+
    del compiled_model.foo
    3059+
    3060+
    self.assertFalse(hasattr(model, "foo"))
    3061+
    self.assertFalse(hasattr(compiled_model, "foo"))
    3062+
    30503063
    def test_globals_change_in_other_file(self):
    30513064
    @torch.compile(backend="eager", fullgraph=True)
    30523065
    def fn(x):

    torch/_dynamo/eval_frame.py

    Lines changed: 9 additions & 0 deletions
    Original file line numberDiff line numberDiff line change
    @@ -389,6 +389,15 @@ def __setattr__(self, name, val) -> None:
    389389
    return super().__setattr__(name, val)
    390390
    return setattr(self._orig_mod, name, val)
    391391

    392+
    def __delattr__(self, name):
    393+
    # This mirrors `__setattr__`
    394+
    if hasattr(type(self), name):
    395+
    return super().__delattr__(name)
    396+
    397+
    if name in OptimizedModule._opt_mod_attributes:
    398+
    return super().__delattr__(name)
    399+
    return delattr(self._orig_mod, name)
    400+
    392401
    def _call_lazy_check(self, *args, **kwargs):
    393402
    if (
    394403
    hasattr(self._orig_mod, "_initialize_hook")

    0 commit comments

    Comments
     (0)
    0