🐛 Describe the bug
After replacing the export API `capture_pre_autograd_graph` with `export_for_training`, we got a `ValueError: Linear partition cannot have more than one output node`.

We tested with OPT-125M, which ties `lm_head.weight` to `decoder.embed_tokens.weight`. The previous export API created two separate nodes for the shared weight (one feeding `embedding`, one feeding `lm_head`), while the new export API creates a single node used by both. The export graphs are below for reference, followed by a minimal sketch that reproduces the behavior without `transformers`.
- w/ `capture_pre_autograd_graph`
```
W1204 00:25:55.768000 4063995 site-packages/torch/_export/__init__.py:64] +============================+
W1204 00:25:55.769000 4063995 site-packages/torch/_export/__init__.py:65] | !!! WARNING !!! |
W1204 00:25:55.770000 4063995 site-packages/torch/_export/__init__.py:66] +============================+
W1204 00:25:55.770000 4063995 site-packages/torch/_export/__init__.py:67] capture_pre_autograd_graph() is deprecated and doesn't provide any function guarantee moving forward.
W1204 00:25:55.771000 4063995 site-packages/torch/_export/__init__.py:68] Please switch to use torch.export.export_for_training instead.
class GraphModule(torch.nn.Module):
def forward(self, input_ids):
arg0: "i64[2, 9]";
arg0, = fx_pytree.tree_flatten_spec(([input_ids], {}), self._in_spec)
arg0_1 = arg0
# File: /home/xxx/transformers/src/transformers/models/opt/modeling_opt.py:859 in forward, code: input_ids = input_ids.view(-1, input_shape[-1])
view: "i64[2, 9]" = torch.ops.aten.view.default(arg0_1, [-1, 9]); arg0_1 = None
# File: /home/xxx/transformers/src/transformers/models/opt/modeling_opt.py:866 in forward, code: inputs_embeds = self.embed_tokens(input_ids)
_param_constant0: "f32[50272, 768]" = self.lm_head_weight
embedding: "f32[2, 9, 768]" = torch.ops.aten.embedding.default(_param_constant0, view, 1); _param_constant0 = view = None
...
# File: /home/xxx/transformers/src/transformers/models/opt/modeling_opt.py:1189 in forward, code: logits = self.lm_head(outputs[0]).contiguous()
_param_constant0_1: "f32[50272, 768]" = self.lm_head_weight
linear_12: "f32[2, 9, 50272]" = torch.ops.aten.linear.default(layer_norm_4, _param_constant0_1); layer_norm_4 = _param_constant0_1 = None
return pytree.tree_unflatten([linear_12, contiguous_1, contiguous_2, contiguous_5, contiguous_6], self._out_spec)
```
- w/ `export_for_training`
```
class GraphModule(torch.nn.Module):
def forward(self, input_ids):
input_ids: "i64[2, 9]";
input_ids, = fx_pytree.tree_flatten_spec(([input_ids], {}), self._in_spec)
# No stacktrace found for following nodes
model_decoder_embed_positions_weight: "f32[2050, 768]" = self.model.decoder.embed_positions.weight
lm_head_weight: "f32[50272, 768]" = self.lm_head.weight; lm_head_weight = None
model_decoder_layers_0_self_attn_layer_norm_weight: "f32[768]" = getattr(self.model.decoder.layers, "0").self_attn_layer_norm.weight
....
model_decoder_final_layer_norm_weight: "f32[768]" = self.model.decoder.final_layer_norm.weight
model_decoder_final_layer_norm_bias: "f32[768]" = self.model.decoder.final_layer_norm.bias
lm_head_weight_1: "f32[50272, 768]" = self.lm_head.weight
# File: /home/xxx/transformers/src/transformers/models/opt/modeling_opt.py:859 in forward, code: input_ids = input_ids.view(-1, input_shape[-1])
view: "i64[2, 9]" = torch.ops.aten.view.default(input_ids, [-1, 9]); input_ids = None
# File: /home/xxx/transformers/src/transformers/models/opt/modeling_opt.py:866 in forward, code: inputs_embeds = self.embed_tokens(input_ids)
embedding: "f32[2, 9, 768]" = torch.ops.aten.embedding.default(lm_head_weight_1, view, 1); view = None
# File: /home/xxx/transformers/src/transformers/models/opt/modeling_opt.py:1189 in forward, code: logits = self.lm_head(outputs[0]).contiguous()
linear_12: "f32[2, 9, 50272]" = torch.ops.aten.linear.default(layer_norm_4, lm_head_weight_1); layer_norm_4 = lm_head_weight_1 = None
return pytree.tree_unflatten((linear_12, contiguous_1, contiguous_2, contiguous_5, contiguous_6), self._out_spec)
```
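For what it's worth, the single-node behavior can be reproduced without `transformers`. Below is a minimal sketch using a hypothetical `TiedToy` module (not from the original report) that ties an embedding table to a linear head the way OPT ties `lm_head` to `embed_tokens`; under `export_for_training`, the tied parameter shows up as one node feeding both `aten.embedding` and `aten.linear`:

```python
import torch

class TiedToy(torch.nn.Module):
    """Hypothetical minimal module mirroring OPT's weight tying:
    the embedding table and the output projection share one Parameter."""
    def __init__(self, vocab=16, dim=4):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab, dim)
        self.head = torch.nn.Linear(dim, vocab, bias=False)
        self.head.weight = self.embed.weight  # tie, as OPT does for lm_head

    def forward(self, ids):
        return self.head(self.embed(ids))

m = TiedToy().eval()
ids = torch.randint(0, 16, (2, 3))
gm = torch.export.export_for_training(m, (ids,)).module()
# Expect a single weight node consumed by both aten.embedding and aten.linear.
gm.print_readable()
```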
Code
```python
import os

import torch
import transformers
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq

with torch.no_grad():
    # Tiny OPT config; lm_head.weight is tied to embed_tokens.weight by default.
    opt_config = transformers.OPTConfig(num_hidden_layers=2)
    opt_model = transformers.AutoModelForCausalLM.from_config(opt_config)
    opt_model.eval()
    assert torch.equal(
        opt_model.model.decoder.embed_tokens.weight, opt_model.lm_head.weight
    ), "Expected lm_head weight to be tied to the embed_tokens weight"
    example_inputs = (torch.randint(0, opt_config.vocab_size, (2, 9), dtype=torch.int64),)
    float_out = opt_model(*example_inputs)
    if os.environ.get("NEW_EXPORT", "1") == "1":
        exported_model = torch.export.export_for_training(
            opt_model,
            args=example_inputs,
        ).module()
    else:
        exported_model = torch._export.capture_pre_autograd_graph(opt_model, args=example_inputs)
    exported_model.print_readable()
    quantizer = xiq.X86InductorQuantizer()
    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
    # Raises "ValueError: Linear partition cannot have more than one output node"
    # when NEW_EXPORT=1.
    prepared_model = prepare_pt2e(exported_model, quantizer)
```
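As a stopgap while this is triaged, untying the weights before export should give `lm_head` its own parameter node, so the quantizer's linear partition sees a single consumer again. This is only a hedged sketch, not a verified fix, and it changes the module (the weights are no longer tied, which matters if the model is trained afterwards):

```python
# Hedged workaround sketch (assumption, not verified): clone the tied weight so
# lm_head owns a separate Parameter before calling export_for_training.
opt_model.lm_head.weight = torch.nn.Parameter(
    opt_model.model.decoder.embed_tokens.weight.detach().clone()
)
exported_model = torch.export.export_for_training(opt_model, args=example_inputs).module()
```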
Versions
```text
Collecting environment information...
PyTorch version: 2.6.0.dev20241013
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (conda-forge gcc 11.4.0-13) 11.4.0
Clang version: Could not collect
CMake version: version 3.30.5
Libc version: glibc-2.31
Python version: 3.11.10 | packaged by conda-forge | (main, Sep 10 2024, 11:01:28) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.4.0-169-generic-x86_64-with-glibc2.31
Is CUDA available: True
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Stepping: 0
Frequency boost: enabled
CPU MHz: 2955.101
CPU max MHz: 2250.0000
CPU min MHz: 1500.0000
BogoMIPS: 4500.26
Virtualization: AMD-V
L1d cache: 4 MiB
L1i cache: 4 MiB
L2 cache: 64 MiB
L3 cache: 512 MiB
NUMA node0 CPU(s): 0-63,128-191
NUMA node1 CPU(s): 64-127,192-255
Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.5
[pip3] torch==2.6.0.dev20241013
[pip3] torchao==0.7.0+gitb2e42ff6
[pip3] torchaudio==2.5.0.dev20241013
[pip3] torchvision==0.20.0.dev20241013
[pip3] triton==3.1.0
[conda] blas 2.116 mkl conda-forge
[conda] blas-devel 3.9.0 16_linux64_mkl conda-forge
[conda] brotlipy 0.7.0 py311h9bf148f_1002 pytorch-nightly
[conda] cffi 1.15.1 py311h9bf148f_3 pytorch-nightly
[conda] cryptography 38.0.4 py311h46ebde7_0 pytorch-nightly
[conda] cuda-cudart 12.4.127 0 nvidia
[conda] cuda-cupti 12.4.127 0 nvidia
[conda] cuda-libraries 12.4.1 0 nvidia
[conda] cuda-nvrtc 12.4.127 0 nvidia
[conda] cuda-nvtx 12.4.127 0 nvidia
[conda] cuda-opencl 12.6.68 0 nvidia
[conda] cuda-runtime 12.4.1 0 nvidia
[conda] filelock 3.9.0 py311_0 pytorch-nightly
[conda] libblas 3.9.0 16_linux64_mkl conda-forge
[conda] libcblas 3.9.0 16_linux64_mkl conda-forge
[conda] libcublas 12.4.5.8 0 nvidia
[conda] libcufft 11.2.1.3 0 nvidia
[conda] libcurand 10.3.7.68 0 nvidia
[conda] libcusolver 11.6.1.9 0 nvidia
[conda] libcusparse 12.3.1.170 0 nvidia
[conda] liblapack 3.9.0 16_linux64_mkl conda-forge
[conda] liblapacke 3.9.0 16_linux64_mkl conda-forge
[conda] libnvjitlink 12.4.127 0 nvidia
[conda] mkl 2022.1.0 h84fe81f_915 conda-forge
[conda] mkl-devel 2022.1.0 ha770c72_916 conda-forge
[conda] mkl-include 2022.1.0 h84fe81f_915 conda-forge
[conda] mpmath 1.2.1 py311_0 pytorch-nightly
[conda] numpy 1.23.5 pypi_0 pypi
[conda] pysocks 1.7.1 py311_0 pytorch-nightly
[conda] pytorch 2.6.0.dev20241013 py3.11_cuda12.4_cudnn9.1.0_0 pytorch-nightly
[conda] pytorch-cuda 12.4 hc786d27_7 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch-nightly
[conda] requests 2.28.1 py311_0 pytorch-nightly
[conda] torchao 0.7.0+gitb2e42ff6 dev_0 <develop>
[conda] torchaudio 2.5.0.dev20241013 py311_cu124 pytorch-nightly
[conda] torchtriton 3.1.0+cf34004b8a py311 pytorch-nightly
[conda] torchvision 0.20.0.dev20241013 py311_cu124 pytorch-nightly
[conda] urllib3 1.26.14 py311_0 pytorch-nightly
```
cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4