Support complex type in ONNX export · Issue #59246 · pytorch/pytorch · GitHub

Closed


david-macleod opened this issue Jun 1, 2021 · 17 comments
Labels
module: onnx Related to torch.onnx triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

david-macleod commented Jun 1, 2021

🐛 Bug

When exporting a model with torch.onnx.export I receive the following error

File "/venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 709, in _export
    proto, export_map = graph._export_onnx(
RuntimeError: unexpected tensor scalar type

The error is raised from here, and I am guessing the cause of the issue is that the return type of the op is ComplexFloat (see trace below), which doesn't seem to be covered by the switch statement (unless that is actually built from a more primitive type).

graph(%x.1 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu)):
%1 : ComplexFloat(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = com.microsoft.experimental::DFT(%x.1) # test.py:45:15
return (%1)

However, I also see a reference to ComplexFloat in symbolic_helper.py here, which suggests it is supported?

To Reproduce

import torch
from torch.onnx import register_custom_op_symbolic

class Model(torch.nn.Module):

    def forward(self, x):
        return torch.fft.fft(x)

def fft(g, self, n, dim, norm):
    return g.op("com.microsoft.experimental::DFT", self)
register_custom_op_symbolic("::fft_fft", fft, 1)


model = Model()
ts_model = torch.jit.script(model)

data = torch.randn(1, 1024)
y = ts_model(data)

torch.onnx.export(
    ts_model,
    (data,),
    "tmp.onnx",
    opset_version=13,
    verbose=True,
    example_outputs=(y,),
)

Expected behavior

Successful export of ONNX graph with ComplexFloat output type.

Environment

PyTorch version: 1.8.1+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.18.4
Libc version: glibc-2.27

Python version: 3.8 (64-bit runtime)
Python platform: Linux-5.4.0-70-generic-x86_64-with-glibc2.27
Is CUDA available: True
CUDA runtime version: 11.2.152
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti
GPU 2: GeForce RTX 2080 Ti
GPU 3: GeForce RTX 2080 Ti
GPU 4: GeForce RTX 2080 Ti
GPU 5: GeForce RTX 2080 Ti
GPU 6: GeForce RTX 2080 Ti
GPU 7: GeForce RTX 2080 Ti
GPU 8: GeForce RTX 2080 Ti
GPU 9: GeForce RTX 2080 Ti

Nvidia driver version: 460.32.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.1.1

cc @houseroad @spandantiwari @lara-hdr @BowenBao @neginraoof @SplitInfinity

@ngimel ngimel added module: onnx Related to torch.onnx triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module labels Jun 1, 2021
Author
david-macleod commented Jun 2, 2021

I have confirmed this is definitely a complex number issue as I get the same error substituting the following into the previous example

def forward(self, x):
    return torch.complex(torch.tensor(1.), torch.tensor(1.))
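For illustration (my own sketch, not part of the original report): the failure is at the dtype level, and `torch.view_as_real` gives a lossless real-valued view of a complex tensor, which is the usual way to keep a graph free of ComplexFloat:

```python
import torch

# A ComplexFloat scalar like the one produced above.
z = torch.complex(torch.tensor(1.0), torch.tensor(1.0))
assert z.dtype == torch.complex64  # the dtype the exporter rejects

# view_as_real shares storage: (real, imag) become a trailing dim of size 2.
r = torch.view_as_real(z)
assert r.dtype == torch.float32 and r.shape == (2,)

# view_as_complex is the lossless inverse.
assert torch.view_as_complex(r) == z
```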

Collaborator
BowenBao commented Jun 2, 2021

@david-macleod Thanks for reporting! Indeed this is where the error arises, and there could be other places, such as shape and type inference and the symbolic functions, that need to be checked and updated to support complex computation. We will take a look and evaluate what is needed.

@sunwell1994

Is there any way to export the network with complex float now?

Collaborator
BowenBao commented Oct 5, 2021

@sunwell1994 This is being tracked internally. We are evaluating the design for both the exporter and ONNX with respect to complex types and operators. For the moment, could you try working around the issue by unwrapping the complex type into two separate tensors?
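That workaround can be sketched roughly as follows (my own illustration, not code from this thread; `RealDFT` is a hypothetical helper name): keep the graph real-valued by expressing the DFT as two real matrix multiplies and stacking real and imaginary parts in a trailing dimension, so no ComplexFloat tensor ever appears in the exported graph.

```python
import math
import torch

class RealDFT(torch.nn.Module):
    """Length-n DFT computed with real tensors only: returns a tensor of
    shape (..., n, 2) whose last dim holds (real, imag)."""

    def __init__(self, n: int):
        super().__init__()
        k = torch.arange(n, dtype=torch.float32)
        # e^{-2*pi*i*k*m/n} = cos(angle) + i*sin(angle) with angle = -2*pi*k*m/n
        angle = -2.0 * math.pi * k[:, None] * k[None, :] / n
        self.register_buffer("cos_basis", torch.cos(angle))
        self.register_buffer("sin_basis", torch.sin(angle))

    def forward(self, x):            # x: real-valued, shape (..., n)
        real = x @ self.cos_basis    # basis matrices are symmetric in (k, m)
        imag = x @ self.sin_basis
        return torch.stack((real, imag), dim=-1)

# Matches torch.fft.fft for real input, without a complex dtype in the graph.
m = RealDFT(8)
x = torch.randn(3, 8)
out = m(x)
ref = torch.fft.fft(x)
assert torch.allclose(out[..., 0], ref.real, atol=1e-4)
assert torch.allclose(out[..., 1], ref.imag, atol=1e-4)
```

The same trick applies at model boundaries: accept and return (..., 2)-shaped real tensors, and convert with torch.view_as_complex / torch.view_as_real outside the exported graph.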

@BowenBao BowenBao self-assigned this Oct 5, 2021
@garymm garymm changed the title Unexpected tensor scalar type with torch.onnx.export and ComplexFloat Support complex type in ONNX export Oct 7, 2021
Collaborator
garymm commented Oct 7, 2021

Tracked internally at Microsoft by https://msdata.visualstudio.com/Vienna/_workitems/edit/1442180

@thiagocrepaldi
Collaborator

Complex types are a known limitation of the ONNX exporter. Although there is no estimate of when this will be resolved, it is being tracked internally.

@BowenBao BowenBao removed their assignment Jan 4, 2023
@abock abock added this to ONNX Jun 14, 2023
@github-project-automation github-project-automation bot moved this to Inbox in ONNX Jun 14, 2023
sammysun0711 commented Jun 15, 2023

Any update about this issue?
I met same issue during export LLM to ONNX: https://huggingface.co/qhduan/aquilachat-7b/discussions/2

@SchweitzerGAO

Any updates?

@justinchuby
Collaborator

Complex support has been added in the torch.onnx.dynamo_export exporter. FFT support will be added. Duplicate of #107588
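For reference, the newer export path looks roughly like this (a sketch assuming PyTorch >= 2.1 with the onnxscript dependency installed; on older versions torch.onnx.dynamo_export does not exist, and per the comment above FFT op coverage was still being added):

```python
import torch

class FFTModel(torch.nn.Module):
    def forward(self, x):
        return torch.fft.fft(x)  # produces a ComplexFloat tensor

model = FFTModel()
x = torch.randn(1, 1024)
assert model(x).dtype == torch.complex64  # the dtype the old exporter rejects

# Assumption: PyTorch >= 2.1 and the `onnxscript` package are available.
if hasattr(torch.onnx, "dynamo_export"):
    try:
        onnx_program = torch.onnx.dynamo_export(model, x)
        onnx_program.save("fft.onnx")
    except Exception as exc:  # e.g. onnxscript missing, or op not yet covered
        print(f"dynamo export unavailable here: {exc}")
```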

@justinchuby justinchuby closed this as not planned Aug 24, 2023
@github-project-automation github-project-automation bot moved this from Inbox to Done in ONNX Aug 24, 2023
@bigmover

@david-macleod Thanks for reporting! Indeed this is where the error arises, and there could be other places, such as shape and type inference and the symbolic functions, that need to be checked and updated to support complex computation. We will take a look and evaluate what is needed.

Ops related to FFT can hit this error. Is it supported now? Is there some way to work around it?

@bigmover

Complex types are a known limitation of the ONNX exporter. Although there is no estimate of when this will be resolved, it is being tracked internally.

I hit the same error when converting the LaMa model to ONNX, and I think it is related to the FFT ops. I tried to use torch.onnx.dynamo_export to solve it, but found that PyTorch 2.0.1 doesn't support dynamo_export. Any advice? Thanks!

Collaborator
justinchuby commented Jan 15, 2024 via email

@TonyPolich

I'm on 2.3.0.dev20240219+cu121 and I still get the same error.

lix19937 commented May 4, 2024

+1

tianyic commented Jun 5, 2024

+1, any update on the complex support?

songh11 commented Jul 26, 2024

+1, torch 2.3.0 with torch.onnx.export also hits this problem.

Contributor
bhack commented Aug 22, 2024

Is it supported with https://pytorch.org/docs/stable/onnx_dynamo.html or not?
