stft: move towards always returning complex #72882
For `stft` this makes all cases where `return_complex` would default to `False` into an error, and adds a warning when `return_complex=False` is passed explicitly. For `istft` this raises an error if the input is not a complex tensor. [ghstack-poisoned]
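As a sketch of the intended post-change behavior (assuming the API this PR moves toward, where `stft` must be asked for a complex result explicitly and `istft` consumes a complex spectrogram):

```python
import torch

x = torch.randn(1024)

# After this change, return_complex=True is the supported path; it yields
# a complex tensor of shape (n_fft // 2 + 1, num_frames) for real input.
X = torch.stft(x, n_fft=256, return_complex=True)
assert X.is_complex()

# istft takes the complex spectrogram; passing `length` recovers the
# original number of samples exactly.
y = torch.istft(X, n_fft=256, length=1024)
assert y.shape == (1024,)
```

Passing `return_complex=False` (or relying on the old default) is the case this PR turns into a warning and, eventually, an error.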
```cpp
            length=length, return_complex=return_complex)
    )SCRIPT"},
    {"stft_0_10", R"SCRIPT(
def stft_0_10(self: Tensor, n_fft: int, hop_length: Optional[int]=None, win_length: Optional[int]=None,
```
I'm getting a major issue with these upgraders. `stft` and `istft` have python definitions in `torch/functional.py` that wrap the ATen operators. So the TorchScript models use `prim::CallFunction` instead of `aten::stft`, and the upgrader is never called.
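To illustrate why the upgrader never fires, here is a toy sketch of the wrapper pattern (the names `aten_stft` and `call_log` are invented for illustration; this is not the real TorchScript machinery): a Python wrapper forwards to the underlying operator, so anything that records the outermost call sees the wrapper, not the operator it wraps.

```python
# Toy model of the python-wrapper problem described above.
call_log = []

def aten_stft(x, n_fft):
    # Stand-in for the ATen operator itself.
    call_log.append("aten::stft")
    return (x, n_fft)

def stft(x, n_fft):
    # Stand-in for the python wrapper in torch/functional.py: a serialized
    # model records this call site (prim::CallFunction), not the aten op
    # that it forwards to.
    call_log.append("prim::CallFunction(stft)")
    return aten_stft(x, n_fft)

stft([0.0] * 4, 2)
# The outermost recorded event is the function call, which is why an
# upgrader keyed on top-level `aten::stft` nodes never matches.
assert call_log[0] == "prim::CallFunction(stft)"
```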
Is there a reason why these functions have to be defined in Python? Can we move them to C++? It seems like they just handle torch function dispatch or otherwise call into `torch.{stft/istft}`.
@tugsbayasgalan should know what to do
We currently iterate over all nodes in the graph and replace the aten operators with upgraders when necessary. I am wondering if we can be a little smarter in that pass and dig deeper when we encounter `prim::CallFunction`. I am not sure how feasible that is @eellison @gmagogsfm. Alternatively, fixing the bug @anjali411 mentioned to unblock moving to C++ might be the way to go. What do you guys think?
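A hedged sketch of the "dig deeper" idea on a toy IR (these `Node`/`Graph` classes are invented for illustration; the real pass operates on TorchScript's graph types): the pass rewrites matching operator nodes in place, and additionally recurses into the body of any called function so upgraders also apply behind a `prim::CallFunction`.

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Node:
    kind: str                          # e.g. "aten::stft" or "prim::CallFunction"
    callee: Optional["Graph"] = None   # body of the called function, if any

@dataclass
class Graph:
    nodes: list = field(default_factory=list)

# Hypothetical upgrader table mapping an old op to its versioned replacement.
UPGRADERS = {"aten::stft": "aten::stft_0_10"}

def apply_upgraders(graph: Graph) -> None:
    for node in graph.nodes:
        if node.kind in UPGRADERS:
            node.kind = UPGRADERS[node.kind]
        elif node.kind == "prim::CallFunction" and node.callee is not None:
            # The "smarter" part: recurse into the called function's body
            # instead of stopping at the call site.
            apply_upgraders(node.callee)

wrapper = Graph(nodes=[Node("aten::stft")])
g = Graph(nodes=[Node("prim::CallFunction", callee=wrapper)])
apply_upgraders(g)
assert wrapper.nodes[0].kind == "aten::stft_0_10"
```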
@peterbell10 please let me know when this PR is ready for review
Sure thing. The actual code shouldn't need to change much, but I'm still waiting on #73434, which is blocked by an FC period that ends this week.
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as
/easycla As part of the transition to the PyTorch Foundation, this project now requires contributions be covered under the new CLA. See #85559 for additional details. This comment will trigger a new check of this PR. If you are already covered, you will simply see a new "EasyCLA" check that passes. If you are not covered, a bot will leave a new comment with a link to sign.
@peterbell10 we should merge this PR right?
Stack from ghstack (oldest at bottom):

For `stft` this makes all cases where `return_complex` would default to `False` into an error and adds a warning when `return_complex=False` is passed explicitly.

For `istft` this raises an error if the input is not a complex tensor.