`.eval()` freezes weights of `torch.compile` modules in inference mode #104984
I conducted some further tests, and it seems that the issue arises from some interaction between eval mode and the `torch.no_grad` context. This test, where I do not use `no_grad`, passes:

```python
# Linear model with sigmoid activation.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 1),
    torch.nn.Sigmoid(),
)
model = torch.compile(model)
# Stub input data.
x_dat = torch.randn(size=(4, 128))
# Gather the model's initial state, and generate new (random) ones.
state_a = {key: val.clone() for key, val in model.state_dict().items()}
state_b = {key: torch.rand_like(val) for key, val in state_a.items()}
# Turn the model to eval mode and run inference.
model.eval()
eval_a = model(x_dat)
# Update the model weights and run inference.
model.load_state_dict(state_b)
eval_b = model(x_dat)
# Turn the model to train mode and re-run inference.
model.train()
train_b = model(x_dat)
# These assertions pass.
assert (eval_a != eval_b).any()
assert (eval_b == train_b).all()
```

However, this one, where I use `no_grad`, fails:

```python
# Create the model anew and gather/generate states again.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 1),
    torch.nn.Sigmoid(),
)
model = torch.compile(model)
state_a = {key: val.clone() for key, val in model.state_dict().items()}
state_b = {key: torch.rand_like(val) for key, val in state_a.items()}
# Do the same as before, but in eval mode, use `no_grad` context.
model.eval()
with torch.no_grad():
    eval_a = model(x_dat)
model.load_state_dict(state_b)
with torch.no_grad():
    eval_b = model(x_dat)
model.train()
train_b = model(x_dat)
# Both assertions fail!
assert (eval_a != eval_b).any()
assert (eval_b == train_b).all()
```
I also tried to use `no_grad`, but while remaining in train mode:

```python
# Create the model anew and gather/generate states again.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 1),
    torch.nn.Sigmoid(),
)
model = torch.compile(model)
state_a = {key: val.clone() for key, val in model.state_dict().items()}
state_b = {key: torch.rand_like(val) for key, val in state_a.items()}
# Run inference within `no_grad`, but in train mode.
with torch.no_grad():
    eval_a = model(x_dat)
model.load_state_dict(state_b)
with torch.no_grad():
    eval_b = model(x_dat)
# This passes: weights were properly used.
assert (eval_a != eval_b).any()
# This raises:
# RuntimeError: addmm(): functions with out=... arguments don't support
# automatic differentiation, but one of the arguments requires grad.
train_b = model(x_dat)
```
From the test above, I infer that the first forward pass should not be run within `no_grad`. The result is that:

```python
# Create the model anew and gather/generate states again.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 1),
    torch.nn.Sigmoid(),
)
model = torch.compile(model)
state_a = {key: val.clone() for key, val in model.state_dict().items()}
state_b = {key: torch.rand_like(val) for key, val in state_a.items()}
# Run a forward pass in train mode.
train_a = model(x_dat)
# Switch to eval and run inference within `no_grad`.
model.eval()
with torch.no_grad():
    eval_a = model(x_dat)
# Update weights and run inference (still in eval + no_grad).
model.load_state_dict(state_b)
with torch.no_grad():
    eval_b = model(x_dat)
# Switch back to train and run the forward again.
model.train()
train_b = model(x_dat)
# These pass...
assert (train_a == eval_a).all()
assert (train_a != train_b).any()
# ... but this fails!
assert (eval_a != eval_b).any()
```

Note that if I re-run the second eval-mode forward after the train-mode one, the weight updates are again ignored.
OK, I think I get it. The issue seems to arise from the first eval-mode call happening within a `no_grad` context:

```python
# Create the model anew and gather/generate states again.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 1),
    torch.nn.Sigmoid(),
)
model = torch.compile(model)
state_a = {key: val.clone() for key, val in model.state_dict().items()}
state_b = {key: torch.rand_like(val) for key, val in state_a.items()}
# Run a forward pass in train mode, and one in eval mode *without no_grad*.
train_a = model(x_dat)
model.eval()
eval_a = model(x_dat)
# Update weights and run inference again, in eval + no-grad.
model.load_state_dict(state_b)
with torch.no_grad():
    eval_b = model(x_dat)
# Switch back to train and run the forward again.
model.train()
train_b = model(x_dat)
# These all pass.
assert (train_a == eval_a).all()
assert (train_b == eval_b).all()
assert (train_a != train_b).any()
assert (eval_a != eval_b).any()
assert train_a.grad_fn is not None
assert train_b.grad_fn is not None
assert eval_a.grad_fn is not None
assert eval_b.grad_fn is None
```

I am glad that I identified this; however, it still seems obscure to me. At least, the documentation did not lead me to anticipate this behavior. Could someone please let me know whether this is to be expected, and if so, perhaps work on documenting it (or point me to a bit of documentation I might have missed)?
Hi, thanks for your investigation. I tested with the 20230709 nightly and the issue does not reproduce. Would you upgrade to the latest nightly and check? Thanks.
Hi,
Sure. And before the new release comes out, you can install a recent nightly build. Please refer to the "INSTALL PYTORCH" section on the pytorch.org page for the detailed install command.
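For reference, a typical nightly-install one-liner (a sketch; the exact command depends on your platform and compute backend, so do check the selector on pytorch.org):

```bash
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
```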
Requiring end-users of my own torch-dependent package to install a nightly version is not a very neat solution, so I will stick to my ugly fix. But thank you anyway, and I will be looking forward to the release of version 2.1.
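For readers facing the same problem: a minimal sketch of the kind of workaround alluded to above, based on the diagnosis earlier in the thread that only the *first* eval-mode forward pass must avoid `no_grad` (the `warmup_eval` helper is hypothetical, not part of any library):

```python
import torch

def warmup_eval(model: torch.nn.Module, dummy_input: torch.Tensor) -> None:
    """Switch a compiled model to eval mode without freezing its weights.

    The experiments above suggest the bug only triggers when the first
    eval-mode forward pass runs inside a `torch.no_grad` context, so this
    deliberately runs a throwaway forward pass outside of it.
    """
    model.eval()
    model(dummy_input)  # first eval-mode call, intentionally without no_grad

# Usage sketch: warm up once after compiling, then use `no_grad` freely.
# model = torch.compile(model)
# warmup_eval(model, x_dat)
# with torch.no_grad():
#     predictions = model(x_dat)
```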
- Torch 2.0 introduced `torch.compile`, a novel utility to optimize computations via their JIT-compiling into optimized kernels. At the moment, compiled modules cannot be saved (but their states, which are shared with the underlying original module, can), and are not compatible with `torch.func` functional execution either.
- As declearn vows to support Torch 2.0, it seems crucial that end-users may use `torch.compile`. However, as long as Torch 1.10-13 versions (and previous versions of DecLearn 2.X) are supported, this handling should not break backward compatibility.
- An initial approach was to enable compiling the handled module as part of `TorchModel.__init__`. However, this proves impractical, as it takes away the assumption that end-users should be able to use their custom-prepared module as-is, including pre-compiled ones. This is all the more true as there are many options to the `torch.compile` function, which DecLearn has no purpose handling.
- Therefore, the approach implemented here is to detect and handle models that were compiled prior to being input into `TorchModel` (see the sketch after this list).
- Here are a few notes and caveats regarding the current implementation:
  - Some impractical steps were taken to ensure weights and gradients have the same name regardless of whether the module was compiled or not, _and regardless of the specific 2.x version of declearn_. When we move to declearn 3, it will be worth revising.
  - A positive consequence of the previous note is that the compilation of the module should not impair cases where some clients are using torch 1 and/or an older 2.x version of declearn.
  - The will to retain user-defined compilation options is mostly lost due to the current lack of recording of this info by torch. This is however expected to evolve in the future, which should enable sharing instructions with clients. See issue 101107 of pytorch: pytorch/pytorch#101107
  - A clumsy bugfix was introduced to avoid an issue where the wrapped compiled model would not take weight updates into account when running in evaluation mode. The status of this hack should be kept under close watch as the issue I opened to report the bug is treated: pytorch/pytorch#104984
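To illustrate the detection approach described above, a minimal sketch (assuming torch 2.0's behavior, where `torch.compile` returns an `OptimizedModule` wrapper that shares the original module's state and exposes it as `_orig_mod`; this is not necessarily how DecLearn implements it):

```python
import torch

def unwrap_compiled(module: torch.nn.Module) -> tuple[torch.nn.Module, bool]:
    """Return the underlying module and whether the input was compiled."""
    # As of torch 2.0, modules returned by `torch.compile` wrap the input
    # module and expose it under the `_orig_mod` attribute.
    original = getattr(module, "_orig_mod", None)
    if original is not None:
        return original, True
    return module, False

# Usage sketch, with a hypothetical user-provided module:
# module, was_compiled = unwrap_compiled(user_module)
```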
🐛 Describe the bug

When using `torch.compile` on a given module, calling that module's `eval` method results in freezing the weights used in inference, so that calling `compiled_module(inputs)` returns deterministic results even after updating the module's `state_dict`. Calling the `train` method does not fix this issue.

I have trouble understanding why this is the case, and cannot come up with a solution to (a) use `torch.compile` to optimize computations, (b) use `eval` to ensure inference runs properly (e.g. to disable dropout), and (c) be able to update the model and keep training and/or evaluating it from the new input state.

Here is a minimal example to reproduce the issue:
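A minimal sketch of such a reproduction, matching the failing experiment reported in the comments above:

```python
import torch

# Simple model, compiled.
model = torch.nn.Sequential(torch.nn.Linear(128, 1), torch.nn.Sigmoid())
model = torch.compile(model)
x_dat = torch.randn(size=(4, 128))

# First eval-mode forward pass, run within `no_grad` (this triggers the bug).
model.eval()
with torch.no_grad():
    eval_a = model(x_dat)

# Update the weights with random values, then run inference again.
state_b = {key: torch.rand_like(val) for key, val in model.state_dict().items()}
model.load_state_dict(state_b)
with torch.no_grad():
    eval_b = model(x_dat)

# Expected to pass, but fails: the weight update was ignored.
assert (eval_a != eval_b).any()
```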
Versions
cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305