Make compiled models serializable · Issue #101107 · pytorch/pytorch · GitHub

Open

Make compiled models serializable #101107

mariosasko opened this issue May 10, 2023 · 24 comments
Labels
compile-cache · module: dynamo · oncall: pt2 · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@mariosasko
Contributor
mariosasko commented May 10, 2023

🐛 Describe the bug

Serializing a compiled model fails with "Can't pickle local object 'convert_frame.<locals>._convert_frame'" when using pickle, and with "cannot pickle 'ConfigModuleInstance' object" when using dill.

A Colab with an example:
https://colab.research.google.com/drive/1v6jUUq86ql1Era4X47cIDj7bzrrz2RZe?usp=sharing
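
For completeness, the failure reproduces locally with a few lines (a sketch mirroring the Colab, assuming torch 2.0.x):

import pickle

import torch

model = torch.nn.Linear(4, 4)
compiled = torch.compile(model)

# On torch 2.0.x this fails with:
#   Can't pickle local object 'convert_frame.<locals>._convert_frame'
pickle.dumps(compiled)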

In Hugging Face Datasets, this error prevents us from generating (deterministic) hashes for transforms (functions) that reference a compiled model, meaning such transforms cannot be cached and must be re-computed every time a dataset is transformed.

(The "export" API for the compiled models would also work for us.)

Error logs

No response

Minified repro

No response

Versions

Colab env with torch 2.0.1 installed
PyTorch version: 2.0.1+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1 
CMake version: version 3.25.2
Libc version: glibc-2.31

Python version: 3.10.11 (main, Apr  5 2023, 14:15:10) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.10.147+-x86_64-with-glibc2.31
Is CUDA available: False
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.7.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   46 bits physical, 48 bits virtual
CPU(s):                          2
On-line CPU(s) list:             0,1
Thread(s) per core:              2
Core(s) per socket:              1
Socket(s):                       1
NUMA node(s):                    1
Vendor ID:                       GenuineIntel
CPU family:                      6
Model:                           79
Model name:                      Intel(R) Xeon(R) CPU @ 2.20GHz
Stepping:                        0
CPU MHz:                         2200.196
BogoMIPS:                        4400.39
Hypervisor vendor:               KVM
Virtualization type:             full
L1d cache:                       32 KiB
L1i cache:                       32 KiB
L2 cache:                        256 KiB
L3 cache:                        55 MiB
NUMA node0 CPU(s):               0,1
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Mitigation; PTE Inversion
Vulnerability Mds:               Vulnerable; SMT Host state unknown
Vulnerability Meltdown:          Vulnerable
Vulnerability Mmio stale data:   Vulnerable
Vulnerability Retbleed:          Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1:        Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2:        Vulnerable, IBPB: disabled, STIBP: disabled, PBRSB-eIBRS: Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Vulnerable
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.22.4
[pip3] torch==2.0.1+cu118
[pip3] torchaudio==2.0.2+cu118
[pip3] torchdata==0.6.0
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.15.1
[pip3] torchvision==0.15.2+cu118
[pip3] triton==2.0.0
[conda] Could not collect

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @soumith @wconstab @ngimel

@msaroufim msaroufim self-assigned this May 10, 2023
@msaroufim
Member
msaroufim commented May 10, 2023

This has come up a few times, so I'm aggregating how this works today and our plans for next steps:

  1. Indeed, you can't pickle an optimized module, but you can pickle the original module because the weights are shared (https://pytorch.org/get-started/pytorch-2.0/#serialization). I'm planning on adding a simple get/set state here to unwrap the original module automatically for people (see the sketch after this list).
  2. Because 1 is annoying, we recently introduced an in-place module compilation API that makes saving and loading work: torch.save/load torch.compiled models #97565
  3. To improve reproducibility we're also thinking of saving all the config arguments that were passed to torch.compile and persisting them when you save and load a model.
  4. But unfortunately, 2 doesn't solve the problem of having to recompile a model when you load it, so cold starts for inference are bad. I'll have a POC working very soon to solve this; the core idea is to dump the entire Inductor cache (which includes a Triton cache, an Inductor cache, and soon an FX cache) into a state dict and reload it later: https://github.com/msaroufim/mlsys-experiments/blob/main/compile-checkpoint/save-hook.py
  5. Conceptually, 4 should work if you assume the same machine type will be used for training and inference. torch.save/load have a contract that guarantees working across devices, and that might not hold for us, so instead maybe we just write some docs recommending that users copy their caches to a networked file system?
  6. Export is obviously a good solution, but so far there's no date for an official release. Export is also focused on environments without Python available; if Python is available, I think 3 will work just fine. EDIT: export is now available.
  7. It might be possible to pickle/dill the entire compiled module, but I haven't figured out how yet since there's lots of dynamic behavior; dill at least is powerful enough to pickle a Python interpreter, so I feel like it should work. One thing we could do is automatically unwrap and pickle the unoptimized module whenever someone tries to pickle the optimized one. EDIT: I got stuck working on this because of extra graph breaks, but I'd be happy to help review and merge if someone wants to pick this up: Save/Load OptimizedModule #101651
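
For reference, here's the point-1/point-2 workflow as a minimal sketch (the file name is illustrative):

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
opt_model = torch.compile(model)
opt_model(torch.randn(2, 8))  # first call triggers compilation

# The compiled wrapper shares its parameters with the original module,
# so saving the original module's state_dict persists the trained weights.
torch.save(model.state_dict(), "weights.pt")

# Later: rebuild the module, load the weights, and compile again
# (the compilation itself is not persisted and happens anew here).
fresh = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
fresh.load_state_dict(torch.load("weights.pt"))
fresh_opt = torch.compile(fresh)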

@mariosasko
Contributor Author

Thanks for the quick response! I did not mention this in the issue description, but the ability to pickle the compiled functions would also be great (we use dill, which can also pickle functions). Considering it's already possible to fetch the original function/method of a compiled function/model, the simplest solution that would work for us is exposing the params passed to torch.compile (e.g., as an attribute of the compile context). Then, we could define a simple reduction function to make the pickling possible.

(If I'm not mistaken, only some params can be fetched/inferred currently (e.g., disable).)
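
For illustration, such a reduction function could look roughly like this; it's only a sketch of the proposal, and _compile_kwargs is a hypothetical attribute that torch does not record today:

import copyreg

import torch
from torch._dynamo.eval_frame import OptimizedModule


def _rebuild_optimized_module(orig_mod, compile_kwargs):
    # Re-compile the plain module when unpickling.
    return torch.compile(orig_mod, **compile_kwargs)


def _reduce_optimized_module(opt_mod):
    # Hypothetical: assumes the arguments passed to torch.compile were
    # recorded on the wrapper (they currently are not).
    compile_kwargs = getattr(opt_mod, "_compile_kwargs", {})
    return _rebuild_optimized_module, (opt_mod._orig_mod, compile_kwargs)


copyreg.pickle(OptimizedModule, _reduce_optimized_module)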

@msaroufim
Member
msaroufim commented May 11, 2023

@mariosasko just wanna make sure I understand: as a first step, it sounds like you're mostly interested in knowing exactly which args were used when compiling a model, for reproducibility. If so, I was also planning on just putting that in the nn.Module state dict.

I would really love to just be able to dill or pickle an entire optimized module, but there are way too many setattrs to make that possible easily. I'll still dig through it to see what's possible.

@Chillee Chillee added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label May 22, 2023
@varunshenoy
varunshenoy commented Aug 29, 2023

Any updates on this, or has anyone found a workaround?

cc: @msaroufim

@msaroufim
Member
msaroufim commented Aug 29, 2023

So the simplest workaround is to save the state dict and not the model, which we mentioned back when 2.0 was released: https://pytorch.org/get-started/pytorch-2.0/#serialization

I tried to get saving the model to work directly here: #101651. It did work (you could effectively save compiled models directly), but my changes introduced some extra graph breaks across the board, which have a performance impact. I couldn't figure that out and don't have the bandwidth to investigate further, but if someone would like to revisit this, I'd be happy to review and merge.

pandrey-fr added a commit to DecLearn/declearn that referenced this issue Aug 30, 2023
- Torch 2.0 introduced `torch.compile`, a novel utility to optimize
  computations by JIT-compiling them into optimized kernels. At the
  moment, compiled modules cannot be saved (but their states, which
  are shared with the underlying original module, can), and are not
  compatible with `torch.func` functional execution either.
- As declearn vows to support Torch 2.0, it seems crucial that end-
  users be able to use `torch.compile`. However, as long as Torch
  1.10-13 versions (and previous versions of DecLearn 2.X) are
  supported, this handling should not break backward compatibility.
- An initial approach was to enable compiling the handled module as
  part of `TorchModel.__init__`. However, this proved impractical,
  as it takes away the assumption that end-users should be able to
  use their custom-prepared module as-is, including pre-compiled
  ones. This is all the more true as there are many options to the
  `torch.compile` function that DecLearn has no business handling.
- Therefore, the approach implemented here is to detect and handle
  models that were compiled prior to being input into `TorchModel`.

- Here are a few notes and caveats regarding the current implementation:
  - Some impractical steps were taken to ensure weights and gradients
    have the same names regardless of whether the module was compiled
    or not, _and regardless of the specific 2.x version of declearn_.
    When we move to declearn 3, it will be worth revising this.
  - A positive consequence of the previous note is that compilation
    of the module should not impair cases where some clients are using
    torch 1 and/or an older 2.x version of declearn.
  - The intent to retain user-defined compilation options is mostly
    unfulfilled due to torch currently not recording that info. This
    is however expected to evolve in the future, which should enable
    sharing instructions with clients. See issue 101107 of pytorch:
    pytorch/pytorch#101107
  - A clumsy bugfix was introduced to avoid an issue where the wrapped
    compiled model would not take weight updates into account when
    running in evaluation mode. The status of this hack should be kept
    under close watch as the issue I opened to report the bug is handled:
    pytorch/pytorch#104984
@YodaEmbedding
YodaEmbedding commented Nov 8, 2023

My workaround is to "repair" checkpoints that contain the undesired "_orig_mod." prefix.

Save the following script:

import sys
import torch


def remove_prefix(text, prefix):
    if text.startswith(prefix):
        return text[len(prefix) :]
    return text


def repair_checkpoint(path):
    ckpt = torch.load(path)
    in_state_dict = ckpt["model_state_dict"]
    pairings = [
        (src_key, remove_prefix(src_key, "_orig_mod."))
        for src_key in in_state_dict.keys()
    ]
    if all(src_key == dest_key for src_key, dest_key in pairings):
        return  # Do not write checkpoint if no need to repair!
    out_state_dict = {}
    for src_key, dest_key in pairings:
        print(f"{src_key}  ==>  {dest_key}")
        out_state_dict[dest_key] = in_state_dict[src_key]
    ckpt["model_state_dict"] = out_state_dict
    torch.save(ckpt, path)


if __name__ == "__main__":
    paths = sys.argv[1:]
    for path in paths:
        print(path)
        repair_checkpoint(path)
        print("========")

Then:

python checkpoint_unwrap_orig_model.py **/*.pth

NOTE: In my checkpoints, the state_dict is actually inside ckpt["model_state_dict"]. If yours is in a different place, adjust that as necessary, e.g. ckpt if your state_dict is exactly the root of the checkpoint.

@wilson97
wilson97 commented Nov 9, 2023

@msaroufim I tried your workaround (torch.compile my model, then save the state_dict, then load a new non-compiled version of my model, finally insert the saved state_dict) and the model is slow (non-compiled). Am I doing something wrong? Or are you saying that I have to recompile regardless of whether I load a compiled state_dict or not?

@pallgeuer
pallgeuer commented Dec 27, 2023

(Quoting @YodaEmbedding's checkpoint-repair workaround from above.)

If possible (i.e. because you can change the code that is generating the checkpoints), you can save model._orig_mod.state_dict() instead of model.state_dict() for compiled models. This will avoid the _orig_mod. prefix everywhere, and therefore not require any fixing on load. You just need to load_state_dict() and torch.compile() after loading.

Generic code that doesn't know whether a model is compiled or not could do something like:
getattr(model, '_orig_mod', model).state_dict()
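
Wrapped into a small helper, that pattern might look like this (just a sketch; the helper name is made up):

import torch


def save_weights(model: torch.nn.Module, path: str) -> None:
    # Unwrap the torch.compile wrapper if present, so the saved keys never
    # carry the "_orig_mod." prefix; plain modules pass through unchanged.
    torch.save(getattr(model, "_orig_mod", model).state_dict(), path)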

@fxmarty
fxmarty commented Feb 20, 2024

This is a duplicate of #93470

@ecstayalive
ecstayalive commented Feb 28, 2024

I usually override the state_dict() and load_state_dict() methods, using a specific structure to solve most problems, for example:

from collections import OrderedDict

import torch.nn as nn


class CustomNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.kernel: nn.Module = ...  # the wrapped (possibly compiled) sub-module

    # Override state_dict(): store the kernel's own state dict under a
    # single "kernel" key instead of flattening it with prefixes.
    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        if destination is None:
            destination = OrderedDict()
        destination["kernel"] = self.kernel.state_dict(keep_vars=keep_vars)
        return destination

    # Override load_state_dict(): route the stored weights back into the kernel.
    def load_state_dict(self, state_dict, strict=True):
        self.kernel.load_state_dict(state_dict["kernel"], strict=strict)

This structure can get rid of the '_orig_mod.' prefix. Note that the code above is only an illustration of the approach and not directly runnable; adjust the structure to fit your own code.

@tugsbayasgalan
Contributor

@msaroufim any updates?

@msaroufim msaroufim removed their assignment Mar 13, 2024
@msaroufim
Member

Nope, unassigning myself for now since I haven't had time to keep fixing issues here.

@fxmarty
fxmarty commented Mar 18, 2024

Isn't torch.export.export kind of similar to what is requested here? https://pytorch.org/docs/stable/export.html

@rfeinman
rfeinman commented Jun 2, 2024

@fxmarty torch.export.export only compiles a single path from the program control flow, as far as I understand.

It would be great to be able to save a model that has run torch.compile so that we do not need to re-compile each time we launch a program! +1 for this

@angelayi
Contributor
angelayi commented Jun 3, 2024

@rfeinman Could you clarify what you mean by "torch.export.export only compiles a single path from the program control flow"? Export should be able to handle control flow if it is rewritten using torch.cond.

@rfeinman
rfeinman commented Jun 3, 2024

@angelayi my understanding is based off of the explanation here: https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#comparison-to-torchscript-and-fx-tracing

Primarily, the advantage of torch.compile lies in its ability to handle arbitrary Python code with minimal changes to existing code.

One case that torch.compile can handle that other compiler solutions struggle with is data-dependent control flow (the if x.sum() < 0: line below).

TorchScript tracing f1 results in silently incorrect results, since only the actual control flow path is traced.

I think the ability of torch.compile to handle arbitrary Python code with minimal changes is a very nice feature, and it would be great if this feature could transfer to serialization (i.e., if we don't have to swap in torch.cond, etc).

@angelayi
Contributor
angelayi commented Jun 3, 2024

@rfeinman that makes sense! torch.export wants a full graph representation of the code, so it requires these code rewrites instead of falling back to the Python code, which is what torch.compile does.

@fxmarty
fxmarty commented Jun 4, 2024

Maybe a single GraphModule is able to represent torch.cond then? I was not aware of that

edit: yep:

import torch

def true_fn(x: torch.Tensor):
    return x.cos() + x.sin()

def false_fn(x: torch.Tensor):
    return x.sin()

class DynamicShapeCondPredicate(torch.nn.Module):
    """
    A basic usage of cond based on dynamic shape predicate.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        def true_fn(x: torch.Tensor):
            return x.cos()

        def false_fn(x: torch.Tensor):
            return x.sin()

        return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))

dyn_shape_mod = DynamicShapeCondPredicate()

res = dyn_shape_mod(torch.randn(10, 10))

##

from torch.export import export

example_args = (torch.randn(10, 10),)

exported_program = export(
    DynamicShapeCondPredicate(), args=example_args
)
print(exported_program)
print(exported_program.graph)
graph():
    %l_x_ : [num_users=1] = placeholder[target=l_x_]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (True, %true_graph_0, %false_graph_0, [%l_x_]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%conditional, 0), kwargs = {})
    return (getitem,)

now when it comes to which hardware/compiler is able to consume that...

@vgoklani

Curious, has this issue been resolved?

Essentially the goal is to save a compiled model so that we don't have to re-compile each time (i.e., serialize the compiled model).

@fernandoFernandeSantos

I'm curious as well, this would be a very nice feature to have

@ezyang
Contributor
ezyang commented Jan 27, 2025

cc @zhxchen17

@j-adamczyk

+1, this would be very useful. Right now, I have to recompile the model on each Docker container startup. This is especially bad when dynamically scaling up Kubernetes microservices or with Lambda functions when a request surge occurs. Particularly since I use very static embedding models (e.g. CLIP, DINOv2) with inputs scaled to the same size, serializing those models after compilation should be relatively easy.

@vgoklani

I've tried using the new cache artifacts API.

It does seem faster, but still takes a while:

import gzip
import logging

import torch

log = logging.getLogger(__name__)
artifacts_filename = "compile_cache_artifacts.bin.gz"  # example path, use your own

# After warming up the compiled model, dump the compile cache artifacts.
artifacts = torch.compiler.save_cache_artifacts()
assert artifacts is not None
artifact_bytes, cache_info = artifacts

with gzip.open(artifacts_filename, "wb") as f:
    f.write(artifact_bytes)
log.info(f"successfully saved artifacts: {cache_info}")

# Later (e.g. in a fresh process), reload the artifacts to pre-populate the caches.
with gzip.open(artifacts_filename, "rb") as f:
    artifact_bytes = f.read()

cache_info = torch.compiler.load_cache_artifacts(artifact_bytes)

@zhxchen17
Contributor

@vgoklani Yeah, probably because load_cache_artifacts() mostly caches AOTAutograd and Inductor artifacts.

Just a small update: we are developing a different frontend that will further strip the compilation artifacts from Dynamo, which will hopefully reduce the loading time to near zero.
