-
Notifications
You must be signed in to change notification settings - Fork 24.4k
Return only complex from stft #62179
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful links
💊 CI failures summary and remediationsAs of commit b6195bb (more details on the Dr. CI page):
🕵️ 1 new failure recognized by patternsThe following CI failures do not appear to be due to upstream breakages:
|
The failing test is |
Yeah, because this is technically BC breaking; previously it was OK to pass |
torch/functional.py
Outdated
@@ -374,18 +374,9 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None, | |||
win_length: Optional[int] = None, window: Optional[Tensor] = None, | |||
center: bool = True, pad_mode: str = 'reflect', normalized: bool = False, | |||
onesided: Optional[bool] = None, | |||
return_complex: Optional[bool] = None) -> Tensor: | |||
return_complex: Optional[bool] = True) -> Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return_complex: Optional[bool] = True) -> Tensor: | |
return_complex: bool = True) -> Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks, done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM although I wonder if it would be less disruptive to completely get rid of return_complex
in the public API and internally still maintain an optional boolean argument and just simply error out when a value is provided for return_complex
saying that stft will always return complex tensor by default and please get rid of this argument. This way we would only be breaking BC once.
@@ -3978,7 +3978,7 @@ | |||
# missing the `pad_mode` and `center` arguments, which are taken care of at | |||
# `torch.functional.py`. They shall be moved here once we have mapping between | |||
# Python strings and C++ Enum in codegen. | |||
- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor | |||
- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool return_complex=True) -> Tensor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the declarations also need to be modified at a few other sites for e.g.,
Line 377 in a46d421
return_complex: Optional[bool] = None) -> Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
another one:
Line 448 in cf1f594
onesided: Optional[bool] = None, return_complex: Optional[bool] = None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you could find other instances by: grep -rs "bool? return_complex"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, I had to check for istft
which also has return_complex
42743f3
to
3f410af
Compare
@mruberry thoughts? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @mattip this PR mostly looks good to me but can you please exclude the submodule (third_party/nccl/nccl
) update changes in your commit?
3f410af
to
f0e43ed
Compare
Sorry, done |
Synced offline with @peterbell10: We agreed that it'd be better to just make |
Changed out of WIP, I will make another PR for the changes needed to only allow complex input to |
Now mypy is failing: should I rebase? |
7ba1ed6
to
b6195bb
Compare
Rebased to clear merge conflict |
@@ -374,18 +374,9 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None, | |||
win_length: Optional[int] = None, window: Optional[Tensor] = None, | |||
center: bool = True, pad_mode: str = 'reflect', normalized: bool = False, | |||
onesided: Optional[bool] = None, | |||
return_complex: Optional[bool] = None) -> Tensor: | |||
return_complex: bool = True) -> Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The default needs to be False
here, otherwise this is a silent BC breaking change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or actually, no it still needs to be None
so complex tensors pass through correctly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand. The intent of this PR is to only return complex Tensors from stft
. So what will having return_complex: Optional[bool] = None
accomplish here, since it cannot be False
later on anyway?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current default is self.is_complex() || (window.defined() && window.is_complex())
, so changing it to true is a silent bc-breaking change. Imagine someone has this code:
stft = torch.stft(...) # doesn't pass any return_complex argument
spectrogram = stft[..., 0]**2 + stft[..., 1]**2
stft
would now return a complex result silently without any warning or error, and the spectrogram calculation is completely wrong now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe we should make return_complex
a required argument here? So, they would have to explicitly specify if it’s True
or False
. And we can throw a warning to urge users to set it to True
since that would be the default behavior in the future release
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is also the case of complex input which already defaults to complex output, so isn't being changed. There's no reason to break that usage.
const bool return_complex = return_complexOpt.value_or( | ||
self.is_complex() || (window.defined() && window.is_complex())); | ||
if (!return_complex) { | ||
if (!return_complexOpt.has_value()) { | ||
TORCH_WARN_ONCE( | ||
"stft will soon require the return_complex parameter be given for real inputs, " | ||
"and will further require that return_complex=True in a future PyTorch release." | ||
); | ||
} | ||
|
||
|
||
// TORCH_WARN_ONCE( | ||
// "stft with return_complex=False is deprecated. In a future pytorch " | ||
// "release, stft will return complex tensors for all inputs, and " | ||
// "return_complex=False will raise an error.\n" | ||
// "Note: you can still call torch.view_as_real on the complex output to " | ||
// "recover the old return format."); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The next step in the deprecation process would be this:
const bool return_complex = return_complexOpt.value_or(
self.is_complex() || (window.defined() && window.is_complex()));
if (!return_complex) {
TORCH_CHECK(
return_complexOpt.has_value(),
"stft requires the return_complex parameter be given for real inputs");
TORCH_WARN_ONCE(
"stft with return_complex=False is deprecated. In a future pytorch "
"release, stft will return complex tensors for all inputs, and "
"return_complex=False will raise an error.\n"
"Note: you can still call torch.view_as_real on the complex output to "
"recover the old return format.");
}
Or if we're okay with breaking BC without warning then TORCH_CHECK(return_complex, ...)
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That seems like a different PR. This one was meant to do something different. I will start over.
Note we're hashing out some BC/FC plans now so BC/FC issues may have to wait a bit |
This turned out to be difficult to do correctly without backward-breaking changes. Closing, maybe someone else can work out a strategy for moving this forward. |
#72882 is on the path to address this issue. |
Fixes the first half of #55948.
The first commit changes
torch.stft
to require (and set as default)return_complex=True
The second part: allow only complex input to
torch.istft
is still TBD. There is some trickiness around the current code, it does something likeso the code that checks the input needs to be carefully rewritten