[PT2E][Quant] `prepare_pt2e` model with tied weights will raise `ValueError` when using `export_for_training` · Issue #142035 · pytorch/pytorch · GitHub
@yiliu30

Description


🐛 Describe the bug

After replacing the export API `capture_pre_autograd_graph` with `export_for_training`, we get a `ValueError: Linear partition cannot have more than one output node`.

We tested with OPT-125M, which ties `lm_head.weight` to `model.decoder.embed_tokens.weight`. The previous export API creates two separate weight nodes, one feeding the embedding and one feeding `lm_head`, while the new export API creates a single node shared by both ops.
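For context, weight tying here means the two modules share a single `nn.Parameter` object rather than merely holding equal values. The minimal sketch below reproduces the same sharing pattern outside of transformers; the `TiedTiny` module is a hypothetical stand-in, not part of the original report:

import torch
import torch.nn as nn

class TiedTiny(nn.Module):
    # Hypothetical module mimicking OPT's tied embedding / lm_head weights.
    def __init__(self, vocab=16, dim=8):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)
        self.lm_head.weight = self.embed.weight  # tie: both names point to one Parameter

    def forward(self, ids):
        return self.lm_head(self.embed(ids))

m = TiedTiny().eval()
assert m.embed.weight is m.lm_head.weight  # same object, not just torch.equal

gm = torch.export.export_for_training(m, (torch.randint(0, 16, (2, 4)),)).module()
gm.print_readable()  # the shared weight is expected to appear as one node feeding both ops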

Here are the exported graphs for reference:

  • w/ capture_pre_autograd_graph
W1204 00:25:55.768000 4063995 site-packages/torch/_export/__init__.py:64] +============================+
W1204 00:25:55.769000 4063995 site-packages/torch/_export/__init__.py:65] |     !!!   WARNING   !!!    |
W1204 00:25:55.770000 4063995 site-packages/torch/_export/__init__.py:66] +============================+
W1204 00:25:55.770000 4063995 site-packages/torch/_export/__init__.py:67] capture_pre_autograd_graph() is deprecated and doesn't provide any function guarantee moving forward.
W1204 00:25:55.771000 4063995 site-packages/torch/_export/__init__.py:68] Please switch to use torch.export.export_for_training instead.
class GraphModule(torch.nn.Module):
    def forward(self, input_ids):
        arg0: "i64[2, 9]"; 
    
        arg0, = fx_pytree.tree_flatten_spec(([input_ids], {}), self._in_spec)
        arg0_1 = arg0
        
         # File: /home/xxx/transformers/src/transformers/models/opt/modeling_opt.py:859 in forward, code: input_ids = input_ids.view(-1, input_shape[-1])
        view: "i64[2, 9]" = torch.ops.aten.view.default(arg0_1, [-1, 9]);  arg0_1 = None
        
         # File: /home/xxx/transformers/src/transformers/models/opt/modeling_opt.py:866 in forward, code: inputs_embeds = self.embed_tokens(input_ids)
        _param_constant0: "f32[50272, 768]" = self.lm_head_weight
        embedding: "f32[2, 9, 768]" = torch.ops.aten.embedding.default(_param_constant0, view, 1);  _param_constant0 = view = None
        
        ...
        # File: /home/xxx/transformers/src/transformers/models/opt/modeling_opt.py:1189 in forward, code: logits = self.lm_head(outputs[0]).contiguous()
        _param_constant0_1: "f32[50272, 768]" = self.lm_head_weight
        linear_12: "f32[2, 9, 50272]" = torch.ops.aten.linear.default(layer_norm_4, _param_constant0_1);  layer_norm_4 = _param_constant0_1 = None
        return pytree.tree_unflatten([linear_12, contiguous_1, contiguous_2, contiguous_5, contiguous_6], self._out_spec)
  • w/ export_for_training
class GraphModule(torch.nn.Module):
    def forward(self, input_ids):
        input_ids: "i64[2, 9]"; 
    
        input_ids, = fx_pytree.tree_flatten_spec(([input_ids], {}), self._in_spec)
        # No stacktrace found for following nodes
        model_decoder_embed_positions_weight: "f32[2050, 768]" = self.model.decoder.embed_positions.weight
        lm_head_weight: "f32[50272, 768]" = self.lm_head.weight;  lm_head_weight = None
        model_decoder_layers_0_self_attn_layer_norm_weight: "f32[768]" = getattr(self.model.decoder.layers, "0").self_attn_layer_norm.weight
        ....
        model_decoder_final_layer_norm_weight: "f32[768]" = self.model.decoder.final_layer_norm.weight
        model_decoder_final_layer_norm_bias: "f32[768]" = self.model.decoder.final_layer_norm.bias
        lm_head_weight_1: "f32[50272, 768]" = self.lm_head.weight
        
         # File: /home/xxx/transformers/src/transformers/models/opt/modeling_opt.py:859 in forward, code: input_ids = input_ids.view(-1, input_shape[-1])
        view: "i64[2, 9]" = torch.ops.aten.view.default(input_ids, [-1, 9]);  input_ids = None
        
         # File: /home/xxx/transformers/src/transformers/models/opt/modeling_opt.py:866 in forward, code: inputs_embeds = self.embed_tokens(input_ids)
        embedding: "f32[2, 9, 768]" = torch.ops.aten.embedding.default(lm_head_weight_1, view, 1);  view = None
        
         # File: /home/xxx/transformers/src/transformers/models/opt/modeling_opt.py:1189 in forward, code: logits = self.lm_head(outputs[0]).contiguous()
        linear_12: "f32[2, 9, 50272]" = torch.ops.aten.linear.default(layer_norm_4, lm_head_weight_1);  layer_norm_4 = lm_head_weight_1 = None
        return pytree.tree_unflatten((linear_12, contiguous_1, contiguous_2, contiguous_5, contiguous_6), self._out_spec)
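A quick way to see the difference programmatically is to list the users of the weight nodes in the exported `GraphModule` (a minimal sketch, assuming `exported_model` is the module produced by the reproduction code below):

for node in exported_model.graph.nodes:
    if node.op == "get_attr" and "lm_head" in str(node.target):
        # Print which ops consume this weight node.
        print(node.name, "->", [user.target for user in node.users])
# With capture_pre_autograd_graph: two weight nodes, each with a single user.
# With export_for_training: a single live weight node whose users include both
# aten.embedding.default and aten.linear.default.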

Code

import os

import torch
import transformers
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq

with torch.no_grad():
    # Build a small OPT model; lm_head.weight is tied to the embedding weight.
    opt_config = transformers.OPTConfig(num_hidden_layers=2)
    opt_model = transformers.AutoModelForCausalLM.from_config(opt_config)
    opt_model.eval()
    assert torch.equal(
        opt_model.model.decoder.embed_tokens.weight, opt_model.lm_head.weight
    ), "Expected lm_head.weight to be tied to embed_tokens.weight"

    example_inputs = (torch.randint(0, opt_config.vocab_size, (2, 9), dtype=torch.int64),)
    float_out = opt_model(*example_inputs)

    if os.environ.get("NEW_EXPORT", "1") == "1":
        # New export API: the tied weight becomes a single shared node.
        exported_model = torch.export.export_for_training(
            opt_model,
            args=example_inputs,
        ).module()
    else:
        # Deprecated export API: the tied weight is duplicated into two nodes.
        exported_model = torch._export.capture_pre_autograd_graph(opt_model, args=example_inputs)
    exported_model.print_readable()

    quantizer = xiq.X86InductorQuantizer()
    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
    # Raises "ValueError: Linear partition cannot have more than one output node"
    # when the model was exported with export_for_training.
    prepared_model = prepare_pt2e(exported_model, quantizer)
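Running this with `NEW_EXPORT=1` (the default) hits the `ValueError` inside `prepare_pt2e`, while `NEW_EXPORT=0` falls back to the deprecated API and completes. As a possible workaround, a sketch (untested, not part of the original report) is to untie the shared parameter before export so the linear op gets its own weight node:

# Hypothetical workaround sketch (untested): untie the shared Parameter before export
# so lm_head and embed_tokens no longer alias the same tensor.
opt_model.lm_head.weight = torch.nn.Parameter(opt_model.lm_head.weight.detach().clone())
assert opt_model.lm_head.weight is not opt_model.model.decoder.embed_tokens.weight
# export_for_training should then emit separate weight nodes for the embedding and the
# linear, which may avoid the multi-output linear partition.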

Versions

Collecting environment information...
PyTorch version: 2.6.0.dev20241013
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (conda-forge gcc 11.4.0-13) 11.4.0
Clang version: Could not collect
CMake version: version 3.30.5
Libc version: glibc-2.31

Python version: 3.11.10 | packaged by conda-forge | (main, Sep 10 2024, 11:01:28) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.4.0-169-generic-x86_64-with-glibc2.31
Is CUDA available: True
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Stepping:                           0
Frequency boost:                    enabled
CPU MHz:                            2955.101
CPU max MHz:                        2250.0000
CPU min MHz:                        1500.0000
BogoMIPS:                           4500.26
Virtualization:                     AMD-V
L1d cache:                          4 MiB
L1i cache:                          4 MiB
L2 cache:                           64 MiB
L3 cache:                           512 MiB
NUMA node0 CPU(s):                  0-63,128-191
NUMA node1 CPU(s):                  64-127,192-255

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.5
[pip3] torch==2.6.0.dev20241013
[pip3] torchao==0.7.0+gitb2e42ff6
[pip3] torchaudio==2.5.0.dev20241013
[pip3] torchvision==0.20.0.dev20241013
[pip3] triton==3.1.0
[conda] blas                      2.116                       mkl    conda-forge
[conda] blas-devel                3.9.0            16_linux64_mkl    conda-forge
[conda] brotlipy                  0.7.0           py311h9bf148f_1002    pytorch-nightly
[conda] cffi                      1.15.1          py311h9bf148f_3    pytorch-nightly
[conda] cryptography              38.0.4          py311h46ebde7_0    pytorch-nightly
[conda] cuda-cudart               12.4.127                      0    nvidia
[conda] cuda-cupti                12.4.127                      0    nvidia
[conda] cuda-libraries            12.4.1                        0    nvidia
[conda] cuda-nvrtc                12.4.127                      0    nvidia
[conda] cuda-nvtx                 12.4.127                      0    nvidia
[conda] cuda-opencl               12.6.68                       0    nvidia
[conda] cuda-runtime              12.4.1                        0    nvidia
[conda] filelock                  3.9.0                   py311_0    pytorch-nightly
[conda] libblas                   3.9.0            16_linux64_mkl    conda-forge
[conda] libcblas                  3.9.0            16_linux64_mkl    conda-forge
[conda] libcublas                 12.4.5.8                      0    nvidia
[conda] libcufft                  11.2.1.3                      0    nvidia
[conda] libcurand                 10.3.7.68                     0    nvidia
[conda] libcusolver               11.6.1.9                      0    nvidia
[conda] libcusparse               12.3.1.170                    0    nvidia
[conda] liblapack                 3.9.0            16_linux64_mkl    conda-forge
[conda] liblapacke                3.9.0            16_linux64_mkl    conda-forge
[conda] libnvjitlink              12.4.127                      0    nvidia
[conda] mkl                       2022.1.0           h84fe81f_915    conda-forge
[conda] mkl-devel                 2022.1.0           ha770c72_916    conda-forge
[conda] mkl-include               2022.1.0           h84fe81f_915    conda-forge
[conda] mpmath                    1.2.1                   py311_0    pytorch-nightly
[conda] numpy                     1.23.5                   pypi_0    pypi
[conda] pysocks                   1.7.1                   py311_0    pytorch-nightly
[conda] pytorch                   2.6.0.dev20241013 py3.11_cuda12.4_cudnn9.1.0_0    pytorch-nightly
[conda] pytorch-cuda              12.4                 hc786d27_7    pytorch-nightly
[conda] pytorch-mutex             1.0                        cuda    pytorch-nightly
[conda] requests                  2.28.1                  py311_0    pytorch-nightly
[conda] torchao                   0.7.0+gitb2e42ff6           dev_0    <develop>
[conda] torchaudio                2.5.0.dev20241013     py311_cu124    pytorch-nightly
[conda] torchtriton               3.1.0+cf34004b8a           py311    pytorch-nightly
[conda] torchvision               0.20.0.dev20241013     py311_cu124    pytorch-nightly
[conda] urllib3                   1.26.14                 py311_0    pytorch-nightly

cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

Labels

export-triaged, oncall: export, oncall: pt2, triaged
