[ONNX] Converting a model with STFT fails · Issue #106850 · pytorch/pytorch · GitHub

Closed
WangHHY19931001 opened this issue Aug 9, 2023 · 5 comments
Labels
low priority We're unlikely to get around to doing this in the near future module: onnx Related to torch.onnx triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@WangHHY19931001
WangHHY19931001 commented Aug 9, 2023

🐛 Describe the bug

Code:

import torch.onnx
import torchaudio
from torch import nn


class DataCov(nn.Module):
    def __init__(self):
        super().__init__()

        self.transform = nn.Sequential(
            torchaudio.transforms.MelSpectrogram(sample_rate=48000, n_fft=1536, hop_length=768, f_min=20, f_max=20000)
        )

    def forward(self, x1):
        return self.transform(x1)



if __name__ == '__main__':
    model = DataCov()
    model.eval()
    model = torch.jit.script(model)

    x = torch.randn(1, 48000 * 12, requires_grad=True)
    args = (x,)
    torch.onnx.export(model, args, 'DataCov.onnx', export_params=True, opset_version=15)

Error:

C:\Users\dell\miniconda3\envs\mos_cov_to_android\Lib\site-packages\torch\onnx\_internal\jit_utils.py:306: UserWarning: Constant folding - Only steps=1 can be constant folded for opset >= 10 onnx::Slice op. Constant folding not applied. (Triggered internally at C:\cb\pytorch_1000000000000\work\torch\csrc\jit\passes\onnx\constant_fold.cpp:181.)
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
================ Diagnostic Run torch.onnx.export version 2.0.1 ================
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 1 ERROR ========================
ERROR: missing-standard-symbolic-function
=========================================
Exporting the operator 'aten::stft' to ONNX opset version 17 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.
None
<Set verbose=True to see more details>


Traceback (most recent call last):
  File "D:\mos_cov_to_android\main.py", line 189, in <module>
    export_datacov_onnx('DataCov.onnx')
  File "D:\mos_cov_to_android\main.py", line 183, in export_datacov_onnx
    torch.onnx.export(model, args, path, export_params=True, opset_version=17)
  File "C:\Users\dell\miniconda3\envs\mos_cov_to_android\Lib\site-packages\torch\onnx\utils.py", line 506, in export
    _export(
  File "C:\Users\dell\miniconda3\envs\mos_cov_to_android\Lib\site-packages\torch\onnx\utils.py", line 1548, in _export
    graph, params_dict, torch_out = _model_to_graph(
                                    ^^^^^^^^^^^^^^^^
  File "C:\Users\dell\miniconda3\envs\mos_cov_to_android\Lib\site-packages\torch\onnx\utils.py", line 1117, in _model_to_graph
    graph = _optimize_graph(
            ^^^^^^^^^^^^^^^^
  File "C:\Users\dell\miniconda3\envs\mos_cov_to_android\Lib\site-packages\torch\onnx\utils.py", line 665, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\dell\miniconda3\envs\mos_cov_to_android\Lib\site-packages\torch\onnx\utils.py", line 1901, in _run_symbolic_function
    raise errors.UnsupportedOperatorError(
torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::stft' to ONNX opset version 17 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.

Versions

Collecting environment information...
PyTorch version: 2.0.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Professional Edition
GCC version: (Rev5, Built by MSYS2 project) 13.1.0
Clang version: Could not collect
CMake version: version 3.26.4
Libc version: N/A

Python version: 3.10.12 | packaged by Anaconda, Inc. | (main, Jul  5 2023, 19:01:18) [MSC v.1916 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19045-SP0
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
(details not captured by the environment script)

Versions of relevant libraries:
[pip3] numpy==1.25.0
[pip3] torch==2.0.1
[pip3] torchaudio==2.0.2
[pip3] torchvision==0.15.2
[conda] blas                      1.0                         mkl    defaults
[conda] cpuonly                   2.0                           0    pytorch
[conda] mkl                       2023.1.0         h8bd8f75_46356    defaults
[conda] mkl-service               2.4.0           py310h2bbff1b_1    defaults
[conda] mkl_fft                   1.3.6           py310h4ed8f06_1    defaults
[conda] mkl_random                1.2.2           py310h4ed8f06_1    defaults
[conda] numpy                     1.25.1                   pypi_0    pypi
[conda] numpy-base                1.25.0          py310h65a83cf_0    defaults
[conda] pytorch                   2.0.1              py3.10_cpu_0    pytorch
[conda] pytorch-mutex             1.0                         cpu    pytorch
[conda] torchaudio                2.0.2                 py310_cpu    pytorch
[conda] torchvision               0.15.2                py310_cpu    pytorch
@WangHHY19931001
Author

Following #106505, I still get the error.

@justinchuby justinchuby changed the title cover to onnx error [ONNX] Converting a model with STFT fails Aug 9, 2023
@justinchuby justinchuby self-assigned this Aug 9, 2023
@justinchuby
Collaborator

I see you are using opset_version 15. Changing to opset_version 17 reveals another error:

SymbolicValueError: STFT does not currently support complex types  [Caused by the value 'input.3 defined in (%input.3 : Float(*, *, device=cpu) = onnx::Reshape[allowzero=0](%input, %75), scope: DataCov::/torch.nn.modules.container.Sequential::transform # /home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/functional.py:649:16
)' (type 'Tensor') in the TorchScript graph. The containing node has kind 'onnx::Reshape'.] 
    (node defined in   File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/functional.py", line 649
        pad = int(n_fft // 2)
        input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
        input = input.view(input.shape[-signal_dim:])
                ~~~~~~~~~~ <--- HERE
    return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]
                    normalized, onesided, return_complex)
)

    Inputs:
        #0: input defined in (%input : Float(*, *, *, device=cpu) = onnx::Pad[mode="reflect"](%40, %64), scope: DataCov::/torch.nn.modules.container.Sequential::transform # /home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/functional.py:648:16
    )  (type 'Tensor')
        #1: 75 defined in (%75 : Long(2, strides=[1], device=cpu) = onnx::Slice(%66, %71, %72, %69, %74), scope: DataCov::/torch.nn.modules.container.Sequential::transform # /home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/functional.py:649:27
    )  (type 'Tensor')
    Outputs:
        #0: input.3 defined in (%input.3 : Float(*, *, device=cpu) = onnx::Reshape[allowzero=0](%input, %75), scope: DataCov::/torch.nn.modules.container.Sequential::transform # /home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/functional.py:649:16
    )  (type 'Tensor')

This is known and I expect we will support it with the new torch.onnx.dynamo_export by PyTorch 2.1.
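In the meantime, a workaround some users apply is to express the STFT as frame extraction plus a real/imaginary DFT matrix multiply, which lowers to ONNX-friendly Reshape/MatMul ops instead of aten::stft. A minimal NumPy sketch of the idea (the function name and shapes are illustrative, not a torchaudio API; the same structure can be written with torch ops inside a module):

```python
import numpy as np

def stft_as_matmul(signal, n_fft, hop):
    """Onesided STFT computed as framing + DFT-basis matmuls (no FFT op)."""
    # Frame the signal: each row is one window of length n_fft.
    n_frames = 1 + (len(signal) - n_fft) // hop
    frames = np.stack([signal[i * hop : i * hop + n_fft] for i in range(n_frames)])
    # Apply a Hann window, as torch.stft does by default when given one.
    frames = frames * np.hanning(n_fft)
    # Real and imaginary DFT bases as plain matrices (n_fft//2 + 1 bins).
    k = np.arange(n_fft // 2 + 1)[:, None]
    n = np.arange(n_fft)[None, :]
    cos_basis = np.cos(2 * np.pi * k * n / n_fft)
    sin_basis = -np.sin(2 * np.pi * k * n / n_fft)
    # MatMul against the fixed bases replaces the FFT.
    real = frames @ cos_basis.T
    imag = frames @ sin_basis.T
    return real + 1j * imag
```

Because the bases are constants, they export as initializers; the magnitude (`real**2 + imag**2`) feeds the Mel filterbank without ever materializing a complex tensor, which is what the opset-17 exporter chokes on.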

@justinchuby justinchuby removed their assignment Aug 9, 2023
@justinchuby justinchuby added the module: onnx Related to torch.onnx label Aug 9, 2023
@github-project-automation github-project-automation bot moved this to Inbox in ONNX Aug 9, 2023
@justinchuby justinchuby added the low priority We're unlikely to get around to doing this in the near future label Aug 9, 2023
@janeyx99 janeyx99 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Aug 9, 2023
@WangHHY19931001
Author

> I see you are using opset_version 15. Changing to opset_version 17 reveals another error: `SymbolicValueError: STFT does not currently support complex types` [...]
>
> This is known and I expect we will support it with the new torch.onnx.dynamo_export by PyTorch 2.1.

So, I need to install the nightly build version?

@justinchuby
Collaborator

Yes. We recommend installing torch-nightly.
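For reference, the nightly CPU wheels can be installed with the preview index (index URL as documented on pytorch.org; swap `cpu` for your CUDA variant if needed):

```shell
# Install the PyTorch nightly (CPU) wheels, including torchaudio.
pip3 install --pre torch torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
```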

@justinchuby
Collaborator

Duplicate of #107588

@justinchuby justinchuby marked this as a duplicate of #107588 Aug 24, 2023
@justinchuby justinchuby closed this as not planned Won't fix, can't repro, duplicate, stale Aug 24, 2023
@github-project-automation github-project-automation bot moved this from Inbox to Done in ONNX Aug 24, 2023