`.eval()` freezes weights of `torch.compile` modules in inference mode · Issue #104984 · pytorch/pytorch

Closed
pandrey-fr opened this issue Jul 11, 2023 · 8 comments
Labels
needs reproduction · oncall: pt2

Comments

@pandrey-fr
pandrey-fr commented Jul 11, 2023

🐛 Describe the bug

When using torch.compile on a given module, calling that module's eval method freezes the weights used for inference, so that calling compiled_module(inputs) keeps returning the same results even after updating the module's state_dict. Calling the train method does not fix this issue.

I have trouble understanding why this is the case, and cannot come up with a solution that lets me (a) use torch.compile to optimize computations, (b) use eval to ensure inference runs properly (e.g. to disable dropout), and (c) update the model and keep training and/or evaluating it from the new state.

Here is a minimal example to reproduce the issue:

import torch


def test_run_inference_twice(
    model: torch.nn.Module,
    x_dat: torch.Tensor,
) -> None:
    """Run the model in inference twice, updating states in-between."""
    # Gather the model's initial state, and generate new (random) ones.
    state_a = {key: val.clone() for key, val in model.state_dict().items()}
    state_b = {key: torch.rand_like(val) for key, val in state_a.items()}

    # Compute predictions in inference mode with the old states.
    model.eval()
    with torch.no_grad():
        preds_a = model(x_dat)
    model.train()

    # Change the model's state.
    model.load_state_dict(state_b)

    # Compute predictions in inference mode with the new states.
    model.eval()
    with torch.no_grad():
        preds_b = model(x_dat)
    model.train()

    # Assert that predictions differ (as they should).
    assert (preds_a != preds_b).any(), "Predictions are the same!"


def test_run_inference_twice_no_eval(
    model: torch.nn.Module,
    x_dat: torch.Tensor,
) -> None:
    """Copy of `test_run_inference_twice`, without `.eval()` / `.train()`."""
    # Gather the model's initial state, and generate new (random) ones.
    state_a = {key: val.clone() for key, val in model.state_dict().items()}
    state_b = {key: torch.rand_like(val) for key, val in state_a.items()}

    # Compute predictions in inference mode with the old states.
    with torch.no_grad():
        preds_a = model(x_dat)

    # Change the model's state.
    model.load_state_dict(state_b)

    # Compute predictions in inference mode with the new states.
    with torch.no_grad():
        preds_b = model(x_dat)

    # Assert that predictions differ (as they should).
    assert (preds_a != preds_b).any(), "Predictions are the same!"


# Linear model with sigmoid activation.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 1),
    torch.nn.Sigmoid(),
)

# Stub input data.
x_dat = torch.randn(size=(4, 128))

# Run without torch.compile: both tests pass.
test_run_inference_twice(model, x_dat)
test_run_inference_twice_no_eval(model, x_dat)

# Run with torch.compile: the first test fails...
compiled_model = torch.compile(model)
test_run_inference_twice(compiled_model, x_dat)  # raises AssertionError
# ... but the second one works.
compiled_model = torch.compile(model)
test_run_inference_twice_no_eval(compiled_model, x_dat)

Versions

Collecting environment information...
PyTorch version: 2.0.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04.1) 11.3.0
Clang version: Could not collect
CMake version: version 3.26.0
Libc version: glibc-2.35

Python version: 3.10.4 (main, Jun 29 2022, 12:14:53) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.19.0-46-generic-x86_64-with-glibc2.35
Is CUDA available: False

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   39 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          8
On-line CPU(s) list:             0-7
Vendor ID:                       GenuineIntel
Model name:                      11th Gen Intel(R) Core(TM) i7-1165G7 @ 2.80GHz
CPU family:                      6
Model:                           140

Versions of relevant libraries:
[pip3] flake8==5.0.4
[pip3] functorch==2.0.0
[pip3] mypy==1.2.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.5
[pip3] torch==2.0.1
[pip3] triton==2.0.0
[conda] Could not collect

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305

@pandrey-fr
Author

I conducted some further tests, and it seems that the issue arises from some interaction between model.eval() and the torch.no_grad() context.

This test, where I do not use torch.no_grad(), passes as expected:

# Linear model with sigmoid activation.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 1),
    torch.nn.Sigmoid(),
)
model = torch.compile(model)

# Stub input data.
x_dat = torch.randn(size=(4, 128))

# Gather the model's initial state, and generate new (random) ones.
state_a = {key: val.clone() for key, val in model.state_dict().items()}
state_b = {key: torch.rand_like(val) for key, val in state_a.items()}

# Turn the model to eval mode and run inference.
model.eval()
eval_a = model(x_dat)
# Update the model weights and run inference.
model.load_state_dict(state_b)
eval_b = model(x_dat)
# Turn the model to train mode and re-run inference.
model.train()
train_b = model(x_dat)

# These assertions pass.
assert (eval_a != eval_b).any()
assert (eval_b == train_b).all()

However, in this one, where I use torch.no_grad() in evaluation mode, the wrong behavior arises again, i.e. the weights are properly updated, but are not properly used within the no_grad context.

# Create the model anew and gather/generate states again.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 1),
    torch.nn.Sigmoid(),
)
model = torch.compile(model)
state_a = {key: val.clone() for key, val in model.state_dict().items()}
state_b = {key: torch.rand_like(val) for key, val in state_a.items()}

# Do the same as before, but in eval mode, use `no_grad` context.
model.eval()
with torch.no_grad():
    eval_a = model(x_dat)
model.load_state_dict(state_b)
with torch.no_grad():
    eval_b = model(x_dat)
model.train()
train_b = model(x_dat)

# Both assertions fail!
assert (eval_a != eval_b).any()
assert (eval_b == train_b).all()

@pandrey-fr
Author

I also tried using torch.no_grad without switching to eval mode. This results in a compiled model that seemingly does not support auto-differentiation:

# Create the model anew and gather/generate states again.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 1),
    torch.nn.Sigmoid(),
)
model = torch.compile(model)
state_a = {key: val.clone() for key, val in model.state_dict().items()}
state_b = {key: torch.rand_like(val) for key, val in state_a.items()}

# Run inference within `no_grad`, but in train mode.
with torch.no_grad():
    eval_a = model(x_dat)

model.load_state_dict(state_b)
with torch.no_grad():
    eval_b = model(x_dat)

# This passes: weights were properly used.
assert (eval_a != eval_b).any()

# This raises:
# RuntimeError: addmm(): functions with out=... arguments don't support
# automatic differentiation, but one of the arguments requires grad.
train_b = model(x_dat)

@pandrey-fr
Author
pandrey-fr commented Jul 12, 2023

From the test above, I infer that the first forward pass should not be run within no_grad, so that the compiled graph includes auto-differentiation operations. I thus tried the following sequence: train-mode forward; eval-mode forward; weights update; eval-mode forward; train-mode forward.

The result is that:

  • the first two forward passes work properly (same results, one with a grad_fn, the other without);
  • the two train-mode forward passes work properly (distinct results reflecting the weights' update);
  • the two eval-mode forward passes yield the same result: again, the weights' update is ignored.
# Create the model anew and gather/generate states again.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 1),
    torch.nn.Sigmoid(),
)
model = torch.compile(model)
state_a = {key: val.clone() for key, val in model.state_dict().items()}
state_b = {key: torch.rand_like(val) for key, val in state_a.items()}

# Run a forward pass in train mode.
train_a = model(x_dat)

# Switch to eval and run inference within `no_grad`.
model.eval()
with torch.no_grad():
    eval_a = model(x_dat)

# Update weights and run inference (still in eval + no_grad).
model.load_state_dict(state_b)
with torch.no_grad():
    eval_b = model(x_dat)

# Switch back to train and run the forward again.
model.train()
train_b = model(x_dat)

# These pass...
assert (train_a == eval_a).all()
assert (train_a != train_b).any()
# ... but this fails!
assert (eval_a != eval_b).any()

Note that if I re-run the second eval-mode forward after the train-mode one, the weights' update is again ignored.
Edit: it also fails if I run inference in eval mode without torch.no_grad, still returning results based on the initial set of weights, and without a grad_fn. Therefore it seems as though eval mode switches to a model that is forcefully in no-grad mode, with frozen, inaccessible/unupdatable weights, because the first eval-mode call happened inside a no_grad context.
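For completeness, here is a sketch of those two follow-up checks, reusing the model, x_dat and eval_a variables from the snippet above (the variable names eval_c / eval_d are only illustrative):

# Re-run an eval-mode forward after the train-mode one: still stale.
model.eval()
with torch.no_grad():
    eval_c = model(x_dat)
assert (eval_c == eval_a).all()  # the weights' update is ignored again

# Run an eval-mode forward *without* no_grad: still stale, and no grad_fn.
eval_d = model(x_dat)
assert (eval_d == eval_a).all()
assert eval_d.grad_fn is None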

@pandrey-fr
Author

OK, I think I get it. The issue seems to arise from the first eval-mode call happening within a no_grad context. If I run the first train-mode and eval-mode forward passes with autodiff, then the remaining forward passes seem to work properly (taking the weights' update into account), including the no_grad-context ones (which properly end up not having an attached grad_fn).

# Create the model anew and gather/generate states again.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 1),
    torch.nn.Sigmoid(),
)
model = torch.compile(model)
state_a = {key: val.clone() for key, val in model.state_dict().items()}
state_b = {key: torch.rand_like(val) for key, val in state_a.items()}

# Run a forward pass in train mode, and one in eval mode *without no_grad*.
train_a = model(x_dat)
model.eval()
eval_a = model(x_dat)

# Update weights and run inference again, in eval + no-grad
model.load_state_dict(state_b)
with torch.no_grad():
    eval_b = model(x_dat)

# Switch back to train and run the forward again.
model.train()
train_b = model(x_dat)

# These all pass.
assert (train_a == eval_a).all()
assert (train_b == eval_b).all()
assert (train_a != train_b).any()
assert (eval_a != eval_b).any()
assert train_a.grad_fn is not None
assert train_b.grad_fn is not None
assert eval_a.grad_fn is not None
assert eval_b.grad_fn is None

I am glad that I identified this, however it still seems obscure to me; at least, the documentation did not lead me to anticipate this behavior. Could someone please let me know whether this is to be expected, and if so perhaps work on documenting it (or point me to a bit of documentation I might have missed)?

@ZailiWang
Contributor

Hi, thanks for your investigation. I tested with the 20230709 nightly and the issue is not reproduced. Would you upgrade to the latest nightly and check? Thanks.

@williamwen42 added the needs reproduction label on Jul 14, 2023
@pandrey-fr
Author

Hi,
Indeed, the problem is solved in the nightly. Should I therefore assume that it is indeed an issue, and wait for the next version to come out?
(in the meantime, a dirty fix is to make a first call in eval mode outside of a no_grad context; later no_grad-context computations then seem to be performed properly)
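For reference, a minimal sketch of that workaround, using the same toy model, x_dat and state_b as in the snippets above:

model = torch.compile(model)

# Warm-up: one eval-mode forward *outside* of `no_grad`, so that the
# compiled graph is not captured with grad mode forced off.
model.eval()
_ = model(x_dat)
model.train()

# From then on, eval + no_grad inference does reflect state_dict updates.
model.load_state_dict(state_b)
model.eval()
with torch.no_grad():
    preds = model(x_dat)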

@ZailiWang
Contributor

Sure. And before the new release comes out, you can install a recent nightly build. Please refer to the "INSTALL PYTORCH" section on the pytorch.org page for the detailed install command.

@pandrey-fr
Author

Requiring end-users of my own torch-dependent package to install a nightly version is not a very neat solution, so I will stick to my ugly fix. But thank you anyway, and I will be looking forward to the release of version 2.1.

pandrey-fr added a commit to DecLearn/declearn that referenced this issue Aug 30, 2023
- Torch 2.0 introduced `torch.compile`, a novel utility to optimize
  computations by JIT-compiling them into optimized kernels. At
  the moment, compiled modules cannot be saved (but their states,
  which are shared with the underlying original module, can), and
  are not compatible with `torch.func` functional execution either.
- As declearn vows to support Torch 2.0, it seems crucial that end-
  users may use `torch.compile`. However, as long as Torch 1.10-13
  versions (and previous versions of DecLearn 2.X) are supported,
  this handling should not break backward compatibility.
- An initial approach was to enable compiling the handled module as
  part of `TorchModel.__init__`. However, this proved impractical,
  as it takes away the assumption that end-users should be able to
  use their custom-prepared module as-is - including pre-compiled
  ones. This is all the more true as there are many options to the
  `torch.compile` function that DecLearn has no business handling.
- Therefore, the approach implemented here is to detect and handle
  models that were compiled prior to being input into `TorchModel`.

- Here are a few notes and caveats regarding the current implementation:
  - Some impractical steps were taken to ensure weights and gradients
    have the same name regardless of whether the module was compiled
    or not, _and regardless of the specific 2.x version of declearn_.
    When we move to declearn 3, it will be worth revising.
  - A positive consequence of the previous note is that the compilation
    of the module should not impair cases where some clients are using
    torch 1 and/or an older 2.x version of declearn.
  - The will to retain user-defined compilation options is mostly lost
    due to torch currently not recording this information. This
    is however expected to evolve in the future, which should enable
    sharing instructions with clients. See issue 101107 of pytorch:
    pytorch/pytorch#101107
  - A clumsy bugfix was introduced to avoid an issue where the wrapped
    compiled model would not take weight updates into account when
    running in evaluation mode. The status of this hack should be kept
    under close watch as the issue I opened to report the bug is addressed:
    pytorch/pytorch#104984
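(As an aside, a minimal illustration of the detection approach mentioned in the commit note above, not the actual declearn code: modules returned by torch.compile wrap the original module and expose it through their _orig_mod attribute, so detection can be as simple as the hypothetical helper below.)

import torch

def unwrap_compiled(module: torch.nn.Module) -> tuple[torch.nn.Module, bool]:
    """Return the underlying module and whether it was produced by torch.compile."""
    # `_orig_mod` is set on modules wrapped by `torch.compile` (OptimizedModule).
    original = getattr(module, "_orig_mod", None)
    if original is not None:
        return original, True
    return module, False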