Optimizer classes not `dill` picklable after using `torch.compile` · Issue #126154 · pytorch/pytorch · GitHub
Open
ringohoffman opened this issue May 14, 2024 · 11 comments
Labels
module: serialization (Issues related to serialization (e.g., via pickle, or otherwise) of PyTorch objects)
oncall: pt2
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@ringohoffman
Contributor
ringohoffman commented May 14, 2024

🐛 Describe the bug

To get started:

$ pip install torch dill

After you use torch.compile, Optimizer subclasses stop being dill picklable:

import dill
import torch

torch.save(torch.optim.AdamW, "a.pth", pickle_module=dill)  # OK

@torch.compile
def bias_gelu_fused_(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """Bias-GeLU fused"""
    x = inp + bias
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

torch.save(torch.optim.AdamW, "fake.txt", pickle_module=dill)  # PicklingError: Can't pickle <built-in function reset_code>: it's not found as torch._C._dynamo.eval_frame.reset_code

This seems to be partly the reason:

opt._init_group = disable(opt._init_group)

I think it boils down to torch._C._dynamo.eval_frame not being importable as a package, but as an attribute of torch._C._dynamo (like an imported module):

import dill
import torch._C._dynamo

print(torch._C._dynamo.eval_frame)  # <module 'torch._C._dynamo.eval_frame'>
print(torch._C._dynamo.eval_frame._CacheEntry)  # <class 'torch._C._dynamo.eval_frame._CacheEntry'>

torch.save(torch._C._dynamo.eval_frame._CacheEntry, "a.pth", pickle_module=dill)  # PicklingError: Can't pickle <class 'torch._C._dynamo.eval_frame._CacheEntry'>: it's not found as torch._C._dynamo.eval_frame._CacheEntry

import torch._C._dynamo.eval_frame  # ModuleNotFoundError: No module named 'torch._C._dynamo.eval_frame'; 'torch._C._dynamo' is not a package

eval_frame is defined here:

auto m = py::handle(eval_frame).cast<py::module>();

Is there some change to dill or to C++-generated modules that would solve this?

Versions

[conda] torch                     2.3.0                    pypi_0    pypi

cc @mruberry @mikaylagawarecki @msaroufim @ezyang @bdhirsh @anijain2305 @chauhang

@ringohoffman ringohoffman changed the title Optimizer classes not dill picklable after using torch.compile Optimizer classes not dill picklable after using torch.compile May 14, 2024
@ringohoffman
Contributor Author
ringohoffman commented May 14, 2024

Cross-posted to dill:

@ezyang ezyang added oncall: pt2 module: serialization Issues related to serialization (e.g., via pickle, or otherwise) of PyTorch objects labels May 14, 2024
@ringohoffman
Contributor Author
ringohoffman commented May 15, 2024

These appearances of torch._C.dynamo:

.tp_name = "torch._C.dynamo.eval_frame._PyInterpreterFrame",

.tp_name = "torch._C.dynamo.eval_frame._PyInterpreterFrame",

are the only 2 in all of this repo, and there is no such package as torch._C.dynamo, but there is torch._C._dynamo.

This would seem to explain this:

import torch._C._dynamo

print(torch._C._dynamo.eval_frame._CacheEntry.__module__)  # torch._C.dynamo.eval_frame

Should these be changed to _dynamo?

EDIT:

Actually, __module__ is torch._C._dynamo.eval_frame in 2.3.0+...

@ringohoffman
Contributor Author

Seems to make some progress, leading to a new error:

import dill

import torch._C._dynamo

print(torch._C._dynamo.eval_frame._CacheEntry.__module__)  # torch._C._dynamo.eval_frame

torch._C._dynamo.eval_frame._CacheEntry.__module__ = "eval_frame"

print(torch._C._dynamo.eval_frame._CacheEntry.__module__)  # eval_frame

with open("fake.txt", "wb") as f:
    dill.dump(torch._C._dynamo.eval_frame._CacheEntry, f)  # PicklingError: Can't pickle <class 'pybind11_builtins.pybind11_type'>: it's not found as pybind11_builtins.pybind11_type

@ringohoffman
Contributor Author
ringohoffman commented May 15, 2024

@ringohoffman
Contributor Author

I patched pickle._Pickler.save_global (hopefully there is an alternative way to achieve the same result without requiring a change to the standard library...) so that when __import__("torch._C._dynamo.eval_frame", level=0) fails, it will:

  1. rsplit off the target module ("eval_frame")
  2. load torch._C._dynamo instead
  3. load eval_frame from torch._C._dynamo
  4. finally load the attribute from eval_frame

class _Pickler:
    ...

    def save_global(self, obj, name=None):
        write = self.write
        memo = self.memo

        if name is None:
            name = getattr(obj, '__qualname__', None)
        if name is None:
            name = obj.__name__

        module_name = whichmodule(obj, name)
        try:
            try:
                __import__(module_name, level=0)
                module = sys.modules[module_name]
                obj2, parent = _getattribute(module, name)
            except ModuleNotFoundError:
                parent_name, module_name = module_name.rsplit(".", 1)
                __import__(parent_name, level=0)
                parent = sys.modules[parent_name]
                module, parent = _getattribute(parent, module_name)
                obj2, module = _getattribute(module, name)
        except (ImportError, KeyError, AttributeError):
            raise PicklingError(
                "Can't pickle %r: it's not found as %s.%s" %
                (obj, module_name, name)) from None
        else:
            if obj2 is not obj:
                raise PicklingError(
                    "Can't pickle %r: it's not the same object as %s.%s" %
                    (obj, module_name, name))

        ...
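
The four fallback steps can also be sketched outside of pickle. This is only an illustration under stated assumptions, not the patch itself: `import_with_fallback` and the `demo_fallback_pkg` modules are hypothetical names made up for this example, and the synthetic parent stands in for a C-extension module that is not a package.

```python
import importlib
import sys
import types


def import_with_fallback(module_name: str) -> types.ModuleType:
    """Import ``module_name``; if that fails, fetch it as an attribute of its parent."""
    try:
        return importlib.import_module(module_name)
    except ModuleNotFoundError:
        # Step 1: split off the last component ("eval_frame" in the issue).
        parent_name, _, child_name = module_name.rpartition(".")
        if not parent_name:
            raise  # top-level name, nothing to fall back to
        # Step 2: load the parent instead.
        parent = importlib.import_module(parent_name)
        # Step 3: read the child module off the parent as a plain attribute.
        try:
            return getattr(parent, child_name)
        except AttributeError:
            raise ModuleNotFoundError(f"No module named {module_name!r}") from None


# Hypothetical stand-ins: a parent that is a module but not a package
# (it has no __path__), with a child that exists only as an attribute.
parent = types.ModuleType("demo_fallback_pkg")
child = types.ModuleType("demo_fallback_pkg.child")
child.answer = 42
parent.child = child
sys.modules["demo_fallback_pkg"] = parent  # only the parent is registered

module = import_with_fallback("demo_fallback_pkg.child")
# Step 4: load the attribute from the resolved module.
print(module is child, getattr(module, "answer"))
```

Modules that import normally take the first branch, so the fallback only runs in the "not a package" case.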

And continued on to get:

TypeError: cannot pickle 'ConfigModuleInstance' object

while trying to serialize torch._dynamo.config, which is a ConfigModuleInstance.
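
A similar TypeError can be reproduced without torch. The sketch below assumes that ConfigModuleInstance is a subclass of types.ModuleType (not confirmed in this thread); `FakeConfigModule` is a hypothetical stand-in, and plain `pickle` is used rather than dill.

```python
import pickle
import types


class FakeConfigModule(types.ModuleType):
    """Hypothetical stand-in for a module subclass that holds config values."""


cfg = FakeConfigModule("fake_config")
cfg.debug = True

error = None
try:
    # Module objects have no default reduction, so pickle refuses them.
    pickle.dumps(cfg)
except TypeError as exc:
    error = exc

print(type(error).__name__, error)
```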

Which I have seen elsewhere:

@ringohoffman
Contributor Author

@ezyang this seems like the same exact problem that was solved by:

I am sure we could hack this the same way but I wonder if I'm onto something with my patch to pickle._Pickler.save_global...

@ezyang
Contributor
ezyang commented May 16, 2024

Well, it's generally considered quite naughty to monkeypatch stuff in the Python standard library as a library. So yeah, probably better to stop exporting things as static methods...

@ringohoffman
Contributor Author

I meant as a proposed change to the standard library, not as a monkey patch. This is basically the situation we are in:

import pickle
import types

eval_frame = types.ModuleType("eval_frame")

def reset_code():
    ...

reset_code.__module__ = "__main__.eval_frame"

eval_frame.reset_code = reset_code

with open("reset_code.pt", "wb") as f:
    pickle._Pickler(f).dump(eval_frame.reset_code)  # _pickle.PicklingError: Can't pickle <function reset_code at 0x7f2383e651f0>: import of module '__main__.eval_frame' failed

eval_frame is not a package and so is not importable as import my_module.eval_frame. It's kind of like a module that was imported into our package, like how torch/__init__.py imports sys:

import torch
torch.sys.path  # is actually just sys.path

import torch.sys  # ModuleNotFoundError: No module named 'torch.sys'

But unlike sys, you can't resolve eval_frame just by importing it from the global scope. It only exists in that file. It is more akin to the result of ModuleType. It is only accessible as in: from torch._C._dynamo import eval_frame. So this is basically what I am proposing: if import xxx.yyy fails, try from xxx import yyy.

I will try to add this into the discussion that you guys have already had on cpython.
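
The difference between the two import forms can be checked with the standard library alone, using os.sys as a stand-in for torch.sys. This is a rough sketch: `from_import` is a hypothetical helper that only approximates what a `from xxx import yyy` statement does.

```python
import importlib
import sys


def from_import(parent_name: str, child_name: str):
    """Approximate ``from parent_name import child_name``."""
    # With a non-empty fromlist, __import__ returns the named module itself.
    parent = __import__(parent_name, fromlist=[child_name], level=0)
    return getattr(parent, child_name)


# ``import os.sys`` fails because os is a plain module, not a package.
try:
    importlib.import_module("os.sys")
    dotted_import_ok = True
except ModuleNotFoundError:
    dotted_import_ok = False

# ``from os import sys`` succeeds because sys is an attribute of os.
resolved = from_import("os", "sys")
print(dotted_import_ok, resolved is sys)
```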

@ringohoffman
Contributor Author

I refined my patch and made the case for the change here:

ringohoffman added a commit to ringohoffman/cpython that referenced this issue May 18, 2024
…age C-modules

There have been recurring issues with PyModule_Create modules in PyTorch; when trying to serialize attributes of these C-modules, pickle fails to import the C-module because it is not a package.

This is the current issue that brought this to my attention: pytorch/pytorch#126154

The existing hack to this issue has been to insert the C-module into sys.modules in order to enable pickle to find them: https://github.com/pytorch/pytorch/pull/38136/files#diff-d7e90d0f94b43db763b44fba679a5c1b4cabe3668aaf34f2aee07de8e2d1b2faR524-R528

Instead of relying on this hack, we can change `pickle`'s approach to loading, which is currently equivalent to `import package.c_module`; instead, we could do `from package import c_module`, which 1) does not care if `c_module` is a package or not, 2) is fully backward compatible with the previous approach, and 3) slots in nicely to the `fromlist` parameter of `__import__`, which we are already using to load modules in `pickle`.
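
For comparison, the sys.modules workaround mentioned above can be sketched with synthetic modules. The `demo_hack_pkg` names are hypothetical and stand in for a C-module that is only reachable as an attribute of its parent; this is an illustration, not the code from the linked PyTorch change.

```python
import pickle
import sys
import types

# A non-package parent with a child module attached as an attribute.
parent = types.ModuleType("demo_hack_pkg")
eval_frame = types.ModuleType("demo_hack_pkg.eval_frame")
parent.eval_frame = eval_frame
sys.modules["demo_hack_pkg"] = parent


def reset_code():
    return "reset"


# Make the function claim to live in the child module.
reset_code.__module__ = "demo_hack_pkg.eval_frame"
eval_frame.reset_code = reset_code

# Before the workaround, pickle cannot import the child module by name.
try:
    pickle.dumps(reset_code)
    failed_before = False
except pickle.PicklingError:
    failed_before = True

# The workaround: register the child under its dotted name.
sys.modules["demo_hack_pkg.eval_frame"] = eval_frame
restored = pickle.loads(pickle.dumps(reset_code))
print(failed_before, restored is reset_code)
```

The cost of this approach is that every such C-module has to be registered by hand, which is what the proposed change to `pickle` tries to avoid.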
@xmfan xmfan added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label May 20, 2024
@ezyang
Contributor
ezyang commented May 21, 2024

OK good luck :)
