Make compiled models serializable #101107
This has come up a few times, so I'm aggregating how this works today and our plans for next steps.
Thanks for the quick response! I did not mention this in the issue description, but the ability to pickle the compiled functions would also be great. (If I'm not mistaken, only some of the compilation params can be fetched/inferred currently.)
@mariosasko just wanna make sure I understand: as a first step, it sounds like you're mostly interested in knowing exactly which args were used when compiling a model, for reproducibility. If so, I was also planning on just putting that in the nn module state dict. I would really love to just be able to dill or pickle an entire optimized module, but there are way too many setattrs to make that possible easily. I'll still dig through it to see what's possible.
Any updates on this, or has anyone found a workaround? cc: @msaroufim
So the simplest workaround is to save the state dict and not the model, which we mentioned back when 2.0 was released: https://pytorch.org/get-started/pytorch-2.0/#serialization. I tried to get saving the model to work directly in #101651, and it did work: you could effectively save compiled models directly. The problem was that my changes introduced some extra graph breaks across the board, which have a performance impact. I couldn't figure it out and I don't have the bandwidth to inspect it further, but if someone would like to revisit it I'd be happy to review and merge.
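For concreteness, a minimal sketch of that workaround (`MyModel` and the file name are placeholders; the compilation itself is not saved, so the model has to be recompiled after loading):

```python
import torch

model = MyModel()                # placeholder for your own nn.Module
compiled = torch.compile(model)
# ... train or run the compiled model ...

# Save the weights rather than the compiled wrapper. Parameters are shared
# with the original module, so saving from `model` avoids `_orig_mod.` keys.
torch.save(model.state_dict(), "weights.pth")

# Later: rebuild the module, load the weights, and compile again.
fresh = MyModel()
fresh.load_state_dict(torch.load("weights.pth"))
fresh = torch.compile(fresh)
```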
- Torch 2.0 introduced `torch.compile`, a novel utility to optimize computations by JIT-compiling them into optimized kernels. At the moment, compiled modules cannot be saved (but their states, which are shared with the underlying original module, can), and they are not compatible with `torch.func` functional execution either.
- As declearn vows to support Torch 2.0, it seems crucial that end-users may use `torch.compile`. However, as long as Torch 1.10-13 versions (and previous versions of DecLearn 2.X) are supported, this handling should not break backward compatibility.
- An initial approach was to enable compiling the handled module as part of `TorchModel.__init__`. However, this proves impractical, as it takes away the assumption that end-users should be able to use their custom-prepared module as-is, including pre-compiled ones. This is all the more true as there are many options to the `torch.compile` function that DecLearn has no purpose handling.
- Therefore, the approach implemented here is to detect and handle models that were compiled prior to being input into `TorchModel` (see the sketch after this list).
- Here are a few notes and caveats regarding the current implementation:
  - Some impractical steps were taken to ensure weights and gradients have the same name regardless of whether the module was compiled or not, _and regardless of the specific 2.x version of declearn_. When we move to declearn 3, it will be worth revising this.
  - A positive consequence of the previous note is that the compilation of the module should not impair cases where some clients are using Torch 1 and/or an older 2.x version of declearn.
  - The will to retain user-defined compilation options is mostly lost, due to the current lack of recording of this information by torch. This is however expected to evolve in the future, which should enable sharing instructions with clients. See issue 101107 of pytorch: pytorch/pytorch#101107
  - A clumsy bugfix was introduced to avoid an issue where the wrapped compiled model would not take weight updates into account when running in evaluation mode. The status of this hack should be kept under close watch as the issue I opened to report the bug is treated: pytorch/pytorch#104984
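A minimal sketch of how such detection could look, assuming only the `_orig_mod` attribute that the `torch.compile` wrapper exposes (an illustration, not declearn's actual code):

```python
import torch


def is_compiled(module: torch.nn.Module) -> bool:
    # torch.compile returns a wrapper that stores the user module as `_orig_mod`.
    return hasattr(module, "_orig_mod")


def unwrap_compiled(module: torch.nn.Module) -> torch.nn.Module:
    # Return the original module for a compiled wrapper, else the module itself.
    return getattr(module, "_orig_mod", module)
```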
My workaround is to "repair" checkpoints that contain the undesired `_orig_mod.` prefix. Save the following script:

```python
import sys

import torch


def remove_prefix(text, prefix):
    if text.startswith(prefix):
        return text[len(prefix) :]
    return text


def repair_checkpoint(path):
    ckpt = torch.load(path)
    in_state_dict = ckpt["model_state_dict"]
    pairings = [
        (src_key, remove_prefix(src_key, "_orig_mod."))
        for src_key in in_state_dict.keys()
    ]
    if all(src_key == dest_key for src_key, dest_key in pairings):
        return  # Do not write the checkpoint if there is nothing to repair!
    out_state_dict = {}
    for src_key, dest_key in pairings:
        print(f"{src_key} ==> {dest_key}")
        out_state_dict[dest_key] = in_state_dict[src_key]
    ckpt["model_state_dict"] = out_state_dict
    torch.save(ckpt, path)


if __name__ == "__main__":
    paths = sys.argv[1:]
    for path in paths:
        print(path)
        repair_checkpoint(path)
        print("========")
```

Then: `python checkpoint_unwrap_orig_model.py **/*.pth`

NOTE: In my checkpoints, the state_dict is actually stored under a `model_state_dict` key (as in the script above); adjust this to your own checkpoint layout.
@msaroufim I tried your workaround (torch.compile my model, then save the state_dict, then load a new non-compiled version of my model, and finally insert the saved state_dict), and the model is slow (non-compiled). Am I doing something wrong? Or are you saying that I have to recompile regardless of whether I load a compiled state_dict or not?
If possible (i.e. if you can change the code that is generating the checkpoints), you can save the state dict of the underlying original module rather than of the compiled wrapper. Generic code that doesn't know whether a model is compiled or not could do something like the sketch below:
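A minimal sketch of what that could look like (an illustration, assuming the `_orig_mod` attribute on the compiled wrapper; the helper name and file path are hypothetical):

```python
import torch


def save_plain_state_dict(model: torch.nn.Module, path: str) -> None:
    # If `model` was wrapped by torch.compile, its original module is exposed
    # as `_orig_mod`; saving that yields keys without the `_orig_mod.` prefix.
    module_to_save = getattr(model, "_orig_mod", model)
    torch.save(module_to_save.state_dict(), path)
```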
This is a duplicate of #93470
I usually override the state_dict() and load_state_dict() methods, using a specific structure that solves most problems, for example:

```python
from collections import OrderedDict

import torch.nn as nn


class CustomNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.kernel: nn.Module = ...  # the wrapped (possibly compiled) module

    # Override state_dict() to delegate to the wrapped module without a prefix.
    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        if destination is None:
            destination = OrderedDict()
        # Collect the kernel's state directly, dropping the outer prefix.
        return self.kernel.state_dict(
            destination=destination, prefix="", keep_vars=keep_vars
        )

    # Override load_state_dict() to forward directly to the wrapped module.
    def load_state_dict(self, state_dict, strict=True):
        return self.kernel.load_state_dict(state_dict, strict=strict)
```

This structure can deal with the '_orig_mod' prefix. However, the code above is just an illustration of the method and cannot be run as-is; you should adjust the structure to your own code.
@msaroufim any updates?
Nope, unassigning myself for now since I haven't had time to keep fixing issues here.
Isn't `torch.export` meant to cover this?
@fxmarty It would be great to be able to save a model that has run through `torch.compile`. As far as I understand, `torch.export.export` only compiles a single path from the program control flow.
@rfeinman Could you clarify what you mean by "torch.export.export only compiles a single path from the program control flow"? Export should be able to handle control flow if it is rewritten using `torch.cond`.
@angelayi my understanding is based on the explanation here: https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#comparison-to-torchscript-and-fx-tracing
I think the ability of `torch.compile` to fall back to eager Python execution, rather than requiring code rewrites, is what matters for my use case.
@rfeinman that makes sense! torch.export wants to get a full graph representation of the code, so it requires these code rewrites instead of defaulting to the Python code, which is what torch.compile does.
Maybe a single GraphModule is able to represent torch.cond then? I was not aware of that.

edit: yep:

```python
import torch


def true_fn(x: torch.Tensor):
    return x.cos() + x.sin()


def false_fn(x: torch.Tensor):
    return x.sin()


class DynamicShapeCondPredicate(torch.nn.Module):
    """
    A basic usage of cond based on dynamic shape predicate.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        def true_fn(x: torch.Tensor):
            return x.cos()

        def false_fn(x: torch.Tensor):
            return x.sin()

        return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))


dyn_shape_mod = DynamicShapeCondPredicate()
res = dyn_shape_mod(torch.randn(10, 10))

##
from torch.export import export

example_args = (torch.randn(10, 10),)
exported_program = export(
    DynamicShapeCondPredicate(), args=example_args
)
print(exported_program)
print(exported_program.graph)
```

Now when it comes to which hardware/compiler is able to consume that...
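On the serialization side, an `ExportedProgram` can be saved to disk and reloaded without re-tracing; a small sketch continuing the example above (the file name is just an example):

```python
# Persist the exported program and load it back, e.g. in another process.
torch.export.save(exported_program, "dyn_shape_cond.pt2")
loaded_program = torch.export.load("dyn_shape_cond.pt2")
print(loaded_program.module()(torch.randn(10, 10)))
```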
Curious, has this issue been resolved? Essentially the goal is to save a compiled model so that we don't have to re-compile each time (i.e. serialize the compiled model).
I'm curious as well, this would be a very nice feature to have.
cc @zhxchen17
+1, this would be very useful. Right now, I have to recompile the model on each Docker container startup. This is especially bad when dynamically scaling up Kubernetes microservices or Lambda functions when a request surge occurs, particularly since I use very static embedding models (e.g. CLIP, DINOv2) with inputs scaled to the same size, so serializing those models after compilation should be relatively easy.
I've tried using the new `torch.compiler.save_cache_artifacts()` / `torch.compiler.load_cache_artifacts()` APIs. It does seem faster, but compilation still takes a while.

Saving:

```python
import gzip

# `artifacts_filename` and `log` come from the surrounding application code.
artifacts = torch.compiler.save_cache_artifacts()
assert artifacts is not None
artifact_bytes, cache_info = artifacts

with gzip.open(artifacts_filename, "wb") as f:
    f.write(artifact_bytes)
log.info(f"successfully saved artifacts: {cache_info}")
```

Loading:

```python
with gzip.open(artifacts_filename, "rb") as f:
    artifact_bytes = f.read()
cache_info = torch.compiler.load_cache_artifacts(artifact_bytes)
```
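For context, my understanding of how these calls are meant to be ordered; a sketch under that assumption, with placeholder names:

```python
import gzip

import torch


def warm_up_and_save(model, example_input, path):
    # Compile and run once so Dynamo/Inductor artifacts actually get produced,
    # then snapshot the compiler caches to disk.
    compiled = torch.compile(model)
    compiled(example_input)
    artifacts = torch.compiler.save_cache_artifacts()
    if artifacts is not None:
        artifact_bytes, _cache_info = artifacts
        with gzip.open(path, "wb") as f:
            f.write(artifact_bytes)


def load_caches_then_compile(model, path):
    # In a fresh process, populate the caches *before* the first compiled call
    # so compilation can hit the cache instead of starting from scratch.
    with gzip.open(path, "rb") as f:
        torch.compiler.load_cache_artifacts(f.read())
    return torch.compile(model)
```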
@vgoklani Yeah, probably because load_cache_artifacts() mostly caches AOTAutograd and Inductor artifacts. Just a small update: we are developing a different frontend which will further strip the compilation artifacts from Dynamo, which hopefully will reduce the loading time to near zero.
🐛 Describe the bug
Serializing a compiled model with `pickle` fails with `Can't pickle local object 'convert_frame.<locals>._convert_frame'`, and with `cannot pickle 'ConfigModuleInstance' object` when using `dill`.

A Colab with an example: https://colab.research.google.com/drive/1v6jUUq86ql1Era4X47cIDj7bzrrz2RZe?usp=sharing
In Hugging Face Datasets, this error stops us from generating (deterministic) hashes for transforms (functions) that reference a compiled model, meaning such transforms cannot be cached and must be re-computed each time a dataset is transformed.
(The "export" API for the compiled models would also work for us.)
Error logs
No response
Minified repro
No response
Versions
Colab env with torch 2.0.1 installed
cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @soumith @wconstab @ngimel