[feature request] Support native ONNX export of FFT-related ops in opset17 (with `inverse=True`, it also includes inverse DFT) · Issue #107588 · pytorch/pytorch · GitHub
[feature request] Support native ONNX export of FFT-related ops in opset17 (with inverse=True, it also includes inverse DFT) #107588


Open
vadimkantorov opened this issue Aug 17, 2023 · 28 comments
Assignees
Labels
module: onnx Related to torch.onnx triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@vadimkantorov
Contributor
vadimkantorov commented Aug 17, 2023

🚀 The feature

Seems that more FFT-related ops are now supported by ONNX opset17 and onnxruntime:

It would be good to have torch.fft.rfft and friends exported natively, without having to construct the basis manually.
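
For context, "constructing the basis manually" usually means writing the one-sided real DFT as two matrix products against precomputed cosine and sine matrices, so that an exporter only sees plain matrix multiplications. The sketch below illustrates the idea in pure Python rather than torch; the helper names (`rdft_basis`, `rdft_via_basis`, `rdft_direct`) are hypothetical and exist only for this illustration.

```python
import cmath
import math

def rdft_basis(n):
    """Cosine and sine bases for a one-sided real DFT of length n (n // 2 + 1 bins)."""
    bins = n // 2 + 1
    cos_b = [[math.cos(2 * math.pi * k * t / n) for t in range(n)] for k in range(bins)]
    sin_b = [[-math.sin(2 * math.pi * k * t / n) for t in range(n)] for k in range(bins)]
    return cos_b, sin_b

def rdft_via_basis(x):
    """Return [(re, im), ...], i.e. the (bins, 2) real layout of the one-sided DFT."""
    cos_b, sin_b = rdft_basis(len(x))
    re = [sum(c * v for c, v in zip(row, x)) for row in cos_b]
    im = [sum(s * v for s, v in zip(row, x)) for row in sin_b]
    return list(zip(re, im))

def rdft_direct(x):
    """Reference: direct complex DFT restricted to the non-negative frequencies."""
    n = len(x)
    return [sum(v * cmath.exp(-2j * math.pi * k * t / n) for t, v in enumerate(x))
            for k in range(n // 2 + 1)]

signal = [0.0, 1.0, 0.5, -0.25, 2.0, -1.0, 0.75, 0.1]
pairs = rdft_via_basis(signal)      # real-valued (re, im) pairs
reference = rdft_direct(signal)     # complex values for comparison
```

The two results agree up to floating-point error, which is why this workaround is usable but tedious: every model has to carry its own basis matrices.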

I think this would also be important to the torchaudio community.

Also related: #107446

Motivation, pitch

N/A

Alternatives

No response

Additional context

No response

@vadimkantorov vadimkantorov changed the title from "Support native ONNX export of FFT-related ops in opset17" to "[feature request] Support native ONNX export of FFT-related ops in opset17" Aug 17, 2023
@vadimkantorov vadimkantorov changed the title from "[feature request] Support native ONNX export of FFT-related ops in opset17" to "[feature request] Support native ONNX export of FFT-related ops in opset17 (with inverse=True, it also includes inverse DFT)" Aug 21, 2023
@vadimkantorov
Contributor Author

Or maybe could you please transfer this issue to the core repo issues?

@mthrok mthrok transferred this issue from pytorch/audio Aug 21, 2023
@cpuhrsch cpuhrsch added module: onnx Related to torch.onnx triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module labels Aug 22, 2023
@github-project-automation github-project-automation bot moved this to Inbox in ONNX Aug 22, 2023
@justinchuby justinchuby self-assigned this Aug 24, 2023
@justinchuby
Collaborator

FFT and STFT will be supported by the onnx.dynamo_export exporter

@WangHHY19931001

FFT and STFT will be supported by the onnx.dynamo_export exporter

see #107922

@vadimkantorov
Contributor Author

@github-project-automation github-project-automation bot moved this from Inbox to Done in ONNX Nov 7, 2023
@vadimkantorov
Contributor Author
vadimkantorov commented Feb 6, 2024

@justinchuby Would it be possible to "backport" support for DFT ops into torch.onnx.export, e.g. enable DFT-17?

Or maybe somehow have an opt-in only module enabling these operators for opset17 (via torch.onnx.register_custom_op_symbolic) or introduce some rudimentary support of opset18/opset20 into torch.onnx.export?

Currently, dynamo_export has some questionable limitations for recurrent models (or even for single steps of recurrent models):

which prevent its successful use in scenarios that require exporting recurrent models or cells

cc @grazder

@justinchuby
Collaborator

some rudimentary support of opset18/opset20

That is possible, but requires refactoring that the team may or may not be able to prioritize. Specifically, we need to use the new Reduce* ops that changed the axis attribute into an input across implementations in different ops.

It is possible to reuse the onnxscript-torchlib implementations.

support for DFT ops

The issue with DFT in torch.onnx.export is that the values are complex numbers, which are not handled. However it is possible to case on the value type and create a real representation in the ONNX graph similar to how the dynamo exporter handles them.

We welcome contributions for either case if it is a priority for you. Thanks!

@vadimkantorov
Contributor Author
vadimkantorov commented Feb 6, 2024

@justinchuby I wonder if it's possible to somehow allow torch.onnx.export to handle only the subgraph .rfft(x).view_as_real() as a first workaround? This should at least unblock the typical scenarios where there is some FFT followed by magnitude or angle calculations or some masking (the calculations would have to be done on (*, 2) real tensors for this, but sometimes that's okay for a workaround).
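
To make the proposed workaround concrete, here is a minimal sketch of what magnitude, angle and masking look like once the spectrum is kept as real-valued (re, im) pairs, which is the layout that view_as_real() produces along its last axis. It uses plain Python lists instead of tensors, and the function names are illustrative assumptions rather than anything from torch or ONNX.

```python
import math

def magnitude(pairs):
    """|z| for each (re, im) pair."""
    return [math.hypot(re, im) for re, im in pairs]

def angle(pairs):
    """arg(z) for each (re, im) pair, in radians."""
    return [math.atan2(im, re) for re, im in pairs]

def apply_mask(pairs, mask):
    """Scale each complex bin by a real mask value; the pair axis is untouched."""
    return [(re * m, im * m) for (re, im), m in zip(pairs, mask)]

spectrum = [(3.0, 4.0), (0.0, -2.0), (-1.0, 0.0)]
mags = magnitude(spectrum)
phases = angle(spectrum)
masked = apply_mask(spectrum, [1.0, 0.0, 0.5])
```

None of these steps needs a complex dtype, which is why restricting export to this one subgraph could already cover many audio front ends.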

@vadimkantorov
Contributor Author
vadimkantorov commented Aug 18, 2024

Seems that TensorProto.COMPLEX64 exists in https://onnx.ai/onnx/api/mapping.html? Or is it somehow only supported in serialization but not as dtype?

If supported as a dtype, what could be done is probably:

  • adapt existing ONNX operators taking (..., 2)-shaped complex inputs to also take (...) if the input is TensorProto.COMPLEX64
  • add a flag for operators producing complex outputs to optionally produce (...) instead of (..., 2), similar to how pytorch gradually migrated to complex dtypes with an explicit arg return_complex = True

Or maybe alternatively what could be introduced is a dtype float32x2, which would emphasize the 2-tuple aspect rather than the complex semantics; this tuple dtype could then be used in other contexts as well.

Maybe these complex interleaved dtypes are not the most efficient representation for all ops, but they would probably make it easier to backport complex number support to the older export mechanism while dynamo_export is getting more stable, and maybe make semantics more similar to PyTorch, which should alleviate dim recalculations...

@justinchuby
Collaborator
justinchuby commented Aug 18, 2024

Even though the complex types exist, there are no operators defined to do computation on them; so only supporting the types themselves does not mean much.

@justinchuby
Collaborator

The alternative may be possible. One issue with the old exporter I think is actually knowing the types of the inputs. If we know that the types are complex, it is easy to treat them so. But for ops like abs(), when the input type is missing we don’t know which logic we should use (complex or real)

@vadimkantorov
Contributor Author
vadimkantorov commented Aug 18, 2024

I think adding such support to the operators would simplify all export (old and new) wrt shape/dim adjustment calculation and in general make it simpler because the semantics are closer to PyTorch... Or maybe making some separately named versions of existing operators that work with these complex dtypes (don't know which versioning/evolution strategy is viewed better by the onnx community)...

is ONNX relying on propagating the dtypes in static way? cause otherwise, maybe it could dynamically change its behavior inside the Abs operator impl... e.g. Abs is maybe not so difficult as it should produce a real tensor in any case, so if it's receiving a complex-dtyped tensor as input at runtime, it should always produce a real tensor still...

It could also be that in the future some other complex representations are experimented with in PyTorch (e.g. keeping the real and imag tensors separate as it makes complex gemm reuse existing efficient fp8 kernels)... So IMO it's normal that representations tend to evolve and several can exist to enable efficiency in different memory layout situations... This is to say that maybe (..., 2) (with arbitrary strides) or (2, ...) representations are not inherently bad by themselves...

the new exporter is also getting some mixed feedback at:

  • ONNX Model done advimman/lama#315 (comment)
    If you compare lama_fp32.onnx and lama.onnx models using [Netron](https://netron.app/), you will notice significant differences in architecture. The model exported via torch.dynamo_export is not suitable for use due to a lack of support in other ONNX converters to other formats, low speed, and the inability to use the model on a GPU. These are the limitations of the new Torch exporter. Therefore, there is no benefit in exporting through it.

@justinchuby
Collaborator
justinchuby commented Aug 18, 2024

I think adding complex support to the operators

I agree. This has been moving very slowly (or not at all). Some suggestions are welcomed!

propagating the dtypes in static way

Could you explain? If you are talking about runtime behavior, I expect, for example, if/when Abs supports complex inputs, the runtime will just do the right thing.

the new exporter is also getting some mixed feedback

We are very aware of this. It will be further improved in torch 2.5.

@vadimkantorov
Contributor Author

Could you explain?

I guess I was just misunderstanding your point that the unavailability of dtype info for a given operator at export time is problematic...

@vadimkantorov
Contributor Author
vadimkantorov commented Aug 18, 2024

Some suggestions are welcomed

I guess first - basic ops related to FFT's / abs / pointwise ops. And then gemms and maybe some matrix ops - like mm/mv/eig/svd/linalg.solve...

@justinchuby
Collaborator

Right now, the strategy for handling complex ops in the new exporter is that, based on the tensor type (real or complex), we dispatch the op to a different ONNX decomposition.

For example, a real valued abs is

def aten_abs(self: TRealOrUInt8) -> TRealOrUInt8:
    """abs(Tensor self) -> Tensor"""

    return op.Abs(self)

whereas a complex one is

def aten_abs_complex(self: TRealOrUInt8) -> TRealOrUInt8:
    """abs(Tensor self) -> Tensor"""
    # self_real = self[..., 0]
    self_real = op.Slice(self, [0], [1], axes=[-1])
    # self_imag = self[..., 1]
    self_imag = op.Slice(self, [1], [2], axes=[-1])
    real_pow = op.Pow(self_real, 2)
    imag_pow = op.Pow(self_imag, 2)
    real_plus_imag = op.Add(real_pow, imag_pow)
    return op.Squeeze(op.Sqrt(real_plus_imag), axes=[-1])

Note that we are still taking a real tensor in ONNX, but we treat it as a special tensor where the last axis represents the real/imag parts.

We are able to do this dispatching only by having the tensor type for each of the inputs of the aten ops available.

Now, if ONNX Abs itself supports complex inputs, then this dispatching is no longer needed and we do not need the type information at export time.
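
The dispatching described above can be sketched as a lookup keyed on the op name and on whether the input is complex. This is a simplified illustration in plain Python, not the exporter's actual code: the `DISPATCH` table, `export_abs`, and the `is_complex` flag are assumptions introduced here to show why the input type must be known at export time.

```python
import math

def abs_real(values):
    """Real path: elementwise absolute value (stands in for ONNX Abs)."""
    return [abs(v) for v in values]

def abs_complex(pairs):
    """Complex path: the input is real-valued, with (re, im) on the last axis."""
    return [math.sqrt(re * re + im * im) for re, im in pairs]

# Hypothetical registry: (op name, input is complex?) -> decomposition.
DISPATCH = {
    ("abs", False): abs_real,
    ("abs", True): abs_complex,
}

def export_abs(values, is_complex):
    """Choose the decomposition at export time; fail if the type is unknown."""
    if is_complex is None:
        raise TypeError("input type unknown: cannot choose real or complex abs")
    return DISPATCH[("abs", is_complex)](values)

real_out = export_abs([-1.5, 2.0], is_complex=False)
complex_out = export_abs([(3.0, 4.0), (0.0, -1.0)], is_complex=True)
```

Passing `is_complex=None` mirrors the old-exporter situation mentioned earlier in the thread: without type information there is no way to pick between the two paths.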

@vadimkantorov
Contributor Author
vadimkantorov commented Aug 19, 2024

I wonder: if my_complex_dtyped_tensor.sum(dim = -1) is called in PyTorch, will the ONNX export correctly figure out the axis to sum across? When you append a dummy dim, negative dimensions also need to be shifted if a complex tensor is being handled. I think introducing proper tupled/complex dtypes in ONNX natively would make it possible to remove this logic...

As for the current impl, the elementwise .Pow() can probably be applied to the input real tensor before the slice, saving one kernel call (especially if .Pow() does not change strides, or if the input is contiguous, in which case there is strictly no difference).

@justinchuby
Collaborator
  1. Yes because we dispatch that to the correct implementation that handles negative axes (essentially turning it to -2). I can see that having complex type support would make this easier and more intuitive. cc @gramalingam
  2. Thanks for the recommendation!
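
The axis adjustment mentioned in point 1 can be sketched as follows. Because the real representation carries one extra trailing axis of size 2, a negative axis of the complex tensor has to be shifted by one, while a non-negative axis stays the same. The helper names are hypothetical and the data is stored as nested Python lists for illustration only.

```python
def shift_axis_for_complex(axis, complex_rank):
    """Map an axis of a complex tensor onto its (..., 2) real representation.

    complex_rank is the rank of the tensor as PyTorch sees it; the real
    representation has one extra trailing axis holding (re, im).
    """
    if not -complex_rank <= axis < complex_rank:
        raise IndexError(f"axis {axis} out of range for rank {complex_rank}")
    # Negative axes count from the end, so they must skip the trailing pair axis.
    return axis - 1 if axis < 0 else axis

def sum_last_complex_axis(rows):
    """Sum a rank-2 complex tensor, stored as rows of (re, im) pairs, over dim=-1."""
    return [(sum(re for re, _ in row), sum(im for _, im in row)) for row in rows]

data = [[(1.0, 2.0), (3.0, -1.0)], [(0.5, 0.5), (-0.5, 1.5)]]
summed = sum_last_complex_axis(data)  # reduces real axis -2, keeps the pair axis
```

With a native complex dtype in ONNX this bookkeeping would disappear, since `dim=-1` would mean the same axis on both sides.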

@justinchuby
Collaborator

@gramalingam I also wonder if 2 slices + add is going to be faster than reduce sum.

@justinchuby justinchuby added this to the 2.5.0 milestone Sep 1, 2024
@justinchuby justinchuby reopened this Sep 1, 2024
@github-project-automation github-project-automation bot moved this from Done to Reopened in ONNX Sep 1, 2024
@gramalingam

@gramalingam I also wonder if 2 slices + add is going to be faster than reduce sum.

A quick response: this means that two values are being added (in parallel across all tensor indices)? Using ReduceSum for adding 2 elements does not seem efficient. Usually ReduceSum is used for adding many elements together.

This may just mean I haven't understood the context or the question correctly. I should probably read the whole thread.

@gramalingam
gramalingam commented Sep 2, 2024

@gramalingam I also wonder if 2 slices + add is going to be faster than reduce sum.

Ok, I went through the thread. I don't understand what approach using ReduceSum you are referring to. If you are referring to the implementation shown above for aten_abs_complex, I don't know how you would use ReduceSum here.

The implementation itself seems reasonable, don't see how it can be improved at this level.

@justinchuby
Collaborator

Sounds good, thanks!

@gramalingam

is ONNX relying on propagating the dtypes in static way?

Yes. Philosophically, ONNX very much assumes that types can be determined statically and that type-based dispatch can be done statically (or ahead of execution time). (Eg., onnxruntime has no support for dynamic, type-based, dispatch inside the operator kernel.)

@gramalingam

@gramalingam I also wonder if 2 slices + add is going to be faster than reduce sum.

Ok, I went through the thread. I don't understand what approach using ReduceSum you are referring to. If you are referring to the implementation shown above for aten_abs_complex, I don't know how you would use ReduceSum here.

The implementation itself seems reasonable, don't see how it can be improved at this level.

Sorry, I understand your suggestion now: you mean adding the square of the real part and the imaginary part. I guess you could take it one step more, and use ReduceL2, eliminating most of the ops. It is an interesting idea and worth doing. The performance could depend a lot on backend optimizations. But I think ReduceL2 gives backends the maximum information to optimize it, and may give better perf. It is better since it exploits data-locality, avoiding multiple iterations through data, etc. But a Reduce implementation will incur overhead for setting up the reduction (which you don't want to do when adding 2 elements).
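
The variants discussed in this exchange compute the same value, which a small pure-Python sketch can confirm: slicing then squaring (the aten_abs_complex shown earlier), squaring the whole tensor before slicing, and a single L2 reduction over the trailing axis of size 2. The function names are illustrative, and nothing here says which variant is faster on a given backend.

```python
import math

def abs_slice_then_pow(pairs):
    """Slice re and im, square each, add, sqrt."""
    re = [p[0] for p in pairs]
    im = [p[1] for p in pairs]
    return [math.sqrt(r * r + i * i) for r, i in zip(re, im)]

def abs_pow_then_slice(pairs):
    """Square the whole (re, im) tensor once, then slice and add."""
    squared = [(p[0] * p[0], p[1] * p[1]) for p in pairs]
    return [math.sqrt(s[0] + s[1]) for s in squared]

def abs_reduce_l2(pairs):
    """L2 norm over the trailing axis of size 2 (the quantity ReduceL2 computes)."""
    return [math.sqrt(sum(v * v for v in p)) for p in pairs]

values = [(3.0, 4.0), (-1.0, 1.0), (0.0, 0.0), (2.5, -6.0)]
a = abs_slice_then_pow(values)
b = abs_pow_then_slice(values)
c = abs_reduce_l2(values)
```

Since the results coincide, the choice between them is purely a question of how many nodes the graph contains and how well a backend fuses or specializes them.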

@gramalingam

The alternative may be possible. One issue with the old exporter I think is actually knowing the types of the inputs. If we know that the types are complex, it is easy to treat them so. But for ops like abs(), when the input type is missing we don’t know which logic we should use (complex or real)

I guess you are saying this is a limitation of the old exporter, but not the new one?

@justinchuby
Collaborator

@justinchuby justinchuby removed this from the 2.5.0 milestone Sep 5, 2024
@tak2hu
tak2hu commented Mar 14, 2025

Edit: I got it working by changing rfftn and irfftn into fftn and ifftn. With rfftn the model fails with an incompatible-shapes dimension error, while fftn does not produce any errors.

The new dynamo exporter worked for me in PyTorch 2.6 for exporting a LaMa image inpainting model (this one) with this code:

# Shell steps, run first to fetch the repository that provides the model wrapper:
#   git clone https://github.com/dmMaze/BallonsTranslator
#   cd BallonsTranslator
import torch
from modules import *  # run from the repository root; provides `inpaint`

paint = inpaint.LamaLarge()
tensor_x = torch.rand((1, 3, 512, 512), dtype=torch.float32)  # dummy image input
tensor_y = torch.rand((1, 1, 512, 512), dtype=torch.float32)  # dummy mask input
onnx_program = torch.onnx.export(paint.model.generator, (tensor_x, tensor_y), dynamo=True)
onnx_program.optimize()
onnx_program.save('lama_large_512px.onnx')

Then, when I want to use the resulting model:

import onnxruntime as ort

sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL

session = ort.InferenceSession("lama_large_512px.onnx", sess_options, providers=['CPUExecutionProvider'])

This loads the model, albeit with lots of warnings like [W:onnxruntime:, graph.cc:109 MergeShapeInfo] Error merging shape info for output. '_fft_c2r_35' source:{1,192,64,33} target:{1,192,64,64}. Falling back to lenient merge..

but it won't run, failing with an error like:
[E:onnxruntime:, sequential_executor.cc:572 ExecuteKernel] Non-zero status code returned while running Add node. Name:'node_Add_178' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:540 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 33 by 64

and if optimizations aren't turned off, or if I try to convert it into ORT format, it fails with:
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Node (node_Add_178) Op (Add) [ShapeInferenceError] Incompatible dimensions
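
The 33 versus 64 mismatch in the messages above is consistent with the usual one-sided real FFT convention, where n real samples give n // 2 + 1 frequency bins and the inverse restores 2 * (m - 1) samples by default. The sketch below only reproduces that shape arithmetic and the broadcasting check in plain Python; the helper names are assumptions for illustration, and it does not reproduce the exporter itself.

```python
def rfft_last_dim(n):
    """Length of the last axis after a one-sided real FFT of n samples."""
    return n // 2 + 1

def irfft_default_last_dim(m):
    """Default output length of an inverse real FFT given m one-sided bins."""
    return 2 * (m - 1)

def broadcastable(a, b):
    """Trailing-aligned broadcasting check, as elementwise ops such as Add apply it."""
    return all(x == y or x == 1 or y == 1 for x, y in zip(reversed(a), reversed(b)))

spatial = (1, 192, 64, 64)
one_sided = spatial[:-1] + (rfft_last_dim(spatial[-1]),)
restored_last = irfft_default_last_dim(one_sided[-1])
can_add = broadcastable(one_sided, spatial)
```

If an inverse real FFT output keeps the one-sided length instead of the restored one, a later elementwise Add against a full-size tensor fails exactly as in the "33 by 64" message.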

the inference code is:

import cv2
import onnxruntime as ort
import numpy as np

sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL

session = ort.InferenceSession("lama_large_512px.onnx", sess_options, providers=['CPUExecutionProvider'])

print("ONNX model loaded successfully!")

# Get model input names and shapes
input_name_img = session.get_inputs()[0].name
input_name_mask = session.get_inputs()[1].name
input_shape = session.get_inputs()[0].shape
input_type = session.get_inputs()[0].type
output_name = session.get_outputs()[0].name

image = cv2.imread("test_image.jpg")
mask_original = cv2.imread("test_mask.png", cv2.IMREAD_GRAYSCALE)

def resize_keepasp(im, new_shape=640, scaleup=True, interpolation=cv2.INTER_LINEAR, stride=None):
    shape = im.shape[:2]
    if new_shape is not None:
        if not isinstance(new_shape, tuple): new_shape = (new_shape, new_shape)
    else: new_shape = shape
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup: r = min(r, 1.0)
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    if stride is not None:
        h, w = new_unpad
        if h % stride != 0 : new_h = (stride - (h % stride)) + h
        else : new_h = h
        if w % stride != 0 : new_w = (stride - (w % stride)) + w
        else : new_w = w
        new_unpad = (new_h, new_w)
    if shape[::-1] != new_unpad: im = cv2.resize(im, new_unpad, interpolation=interpolation)
    return im

def inpaint_preprocess(img: np.ndarray, mask: np.ndarray) -> tuple:
    img_original = np.copy(img)
    mask_original = np.copy(mask)
    mask_original[mask_original < 127] = 0
    mask_original[mask_original >= 127] = 1
    mask_original = mask_original[:, :, None]
    new_shape = 512 if max(img.shape[0: 2]) > 512 else None
    img = resize_keepasp(img, new_shape, stride=64)
    mask = resize_keepasp(mask, new_shape, stride=64)
    im_h, im_w = img.shape[:2]
    longer = max(im_h, im_w)
    pad_bottom = longer - im_h if im_h < longer else 0
    pad_right = longer - im_w if im_w < longer else 0
    mask = cv2.copyMakeBorder(mask, 0, pad_bottom, 0, pad_right, cv2.BORDER_REFLECT)
    img = cv2.copyMakeBorder(img, 0, pad_bottom, 0, pad_right, cv2.BORDER_REFLECT)
    img_input = np.expand_dims(img.transpose(2, 0, 1), axis=0).astype(np.float32) / 255.0
    mask_input = np.expand_dims(mask, axis=(0, 1)).astype(np.float32) / 255.0
    mask_input[mask_input < 0.5] = 0
    mask_input[mask_input >= 0.5] = 1
    img_input *= (1 - mask_input)
    return img_input, mask_input, img_original, mask_original, pad_bottom, pad_right

img_input, mask_input, img_original, mask_original, pad_bottom, pad_right = inpaint_preprocess(image, mask_original)

output = session.run([output_name], {input_name_img: img_input, input_name_mask: mask_input})

output_tensor = output[0]

print("Inference completed!")

@QiangJI123

@justinchuby Hi, sorry to bother you, but I can't export torch.fft.fft/ifft to ONNX. I have seen all your answers (update to torch-nightly and use onnx.dynamo_export), but they don't work. Could you provide an official statement on whether the translation is complete?

@justinchuby
Collaborator

Can you update onnxscript? What is the version?

Projects
Status: Reopened