stft: move towards always returning complex by peterbell10 · Pull Request #72882 · pytorch/pytorch

Closed · wants to merge 4 commits
31 changes: 13 additions & 18 deletions aten/src/ATen/native/SpectralOps.cpp
@@ -791,20 +791,17 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop
   const bool return_complex = return_complexOpt.value_or(
       self.is_complex() || (window.defined() && window.is_complex()));
   if (!return_complex) {
-    if (!return_complexOpt.has_value()) {
-      TORCH_WARN_ONCE(
-        "stft will soon require the return_complex parameter be given for real inputs, "
-        "and will further require that return_complex=True in a future PyTorch release."
-      );
-    }
+    TORCH_CHECK(return_complexOpt.has_value(),
+      "stft requires the return_complex parameter be given for real inputs, "
+      "and will further require that return_complex=True in a future PyTorch release.");


-    // TORCH_WARN_ONCE(
-    //   "stft with return_complex=False is deprecated. In a future pytorch "
-    //   "release, stft will return complex tensors for all inputs, and "
-    //   "return_complex=False will raise an error.\n"
-    //   "Note: you can still call torch.view_as_real on the complex output to "
-    //   "recover the old return format.");
+    TORCH_WARN_ONCE(
+      "stft with return_complex=False is deprecated. In a future pytorch "
+      "release, stft will return complex tensors for all inputs, and "
+      "return_complex=False will raise an error.\n"
+      "Note: you can still call torch.view_as_real on the complex output to "
+      "recover the old return format.");
   }

   if (!at::isFloatingType(self.scalar_type()) && !at::isComplexType(self.scalar_type())) {
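
For real inputs that omit return_complex, the old once-per-process warning is thus upgraded to a hard TORCH_CHECK, and the previously commented-out deprecation warning for return_complex=False is switched on. A minimal migration sketch on the Python side (hypothetical tensors; per the warning text, view_as_real recovers the legacy layout):

import torch

x = torch.rand(100)                  # real signal
window = torch.hann_window(10)

# After this change, omitting return_complex for a real input raises:
spec = torch.stft(x, n_fft=10, window=window, return_complex=True)

# Callers that still need the old (..., freq, frames, 2) real layout:
legacy = torch.view_as_real(spec)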
@@ -968,12 +965,10 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> ho
   const auto hop_length = hop_lengthOpt.value_or(n_fft >> 2);
   const auto win_length = win_lengthOpt.value_or(n_fft);

-  if (!self.is_complex()) {
-    TORCH_WARN_ONCE(
-      "istft will require a complex-valued input tensor in a future PyTorch release. "
-      "Matching the output from stft with return_complex=True. ");
-  }
-  Tensor input = self.is_complex() ? self.is_conj() ? at::view_as_real(self.resolve_conj()) : at::view_as_real(self) : self;
+  TORCH_CHECK(self.is_complex(),
+    "istft requires a complex-valued input tensor matching the "
+    "output from stft with return_complex=True.");
+  Tensor input = at::view_as_real(self.resolve_conj());
   const auto input_dim = input.dim();
   const auto n_frames = input.size(-2);
   const auto fft_size = input.size(-3);
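istft goes a step further: real-valued input is rejected outright, and the conjugate handling collapses into a single view_as_real(resolve_conj()) call. A sketch of updating a caller that still holds the legacy real-valued layout (assumes the trailing dimension holds the real/imaginary pair):

import torch

n_fft = 10
window = torch.hann_window(n_fft)
spec = torch.stft(torch.rand(100), n_fft=n_fft, window=window,
                  return_complex=True)
legacy = torch.view_as_real(spec)    # old real-valued stft format

# torch.istft(legacy, ...) now raises; convert back to complex first:
recon = torch.istft(torch.view_as_complex(legacy), n_fft=n_fft, window=window)
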
8 changes: 6 additions & 2 deletions caffe2/serialize/versions.h
@@ -12,7 +12,7 @@ namespace serialize {
 constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;

 #if ENABLE_UPGRADERS
-constexpr uint64_t kMaxSupportedFileFormatVersion = 0xAL;
+constexpr uint64_t kMaxSupportedFileFormatVersion = 11UL;
 #else
 constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
 #endif
@@ -83,7 +83,11 @@ constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
 // Bump the version number to 10 to update aten::gelu and
 // aten::gelu.out to support the new approximate kwarg.
 // (see: https://github.com/pytorch/pytorch/pull/61439)
-constexpr uint64_t kProducedFileFormatVersion = 0xAL;
+// 4) [02/15/2022]
+// Bump the version number to 11 to update aten::stft and
+// aten::istft to deprecate real-dtype complex representation
+// (see: ###)
+constexpr uint64_t kProducedFileFormatVersion = 11UL;
 #else
 constexpr uint64_t kProducedFileFormatVersion = 0x3L;
 #endif
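
The produced version is stamped into every saved archive, which is how the runtime later decides whether the stft/istft upgraders must run. For the curious, a rough way to read that stamp back, assuming the standard TorchScript zip layout (the exact entry path is an assumption here, not something this PR defines):

import zipfile

def produced_format_version(path):
    # TorchScript archives are zip files whose "version" entry records
    # kProducedFileFormatVersion at save time (assumed layout).
    with zipfile.ZipFile(path) as zf:
        name = next(n for n in zf.namelist() if n.split("/")[-1] == "version")
        return int(zf.read(name).decode().strip())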
Binary file added test/jit/fixtures/test_versioned_istft_v10.ptl (not shown)
Binary file added test/jit/fixtures/test_versioned_stft_v10.ptl (not shown)
14 changes: 14 additions & 0 deletions test/jit/fixtures_srcs/fixtures_src.py
@@ -57,3 +57,17 @@ def __init__(self):
     def forward(self, x):
         out = torch.zeros_like(x)
         return torch._C._nn.gelu(x, out=out)
+
+class TestVersionedStftV10(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, n_fft: int, window):
+        return torch.stft(x, n_fft=n_fft, window=window)
+
+class TestVersionedIstftV10(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, n_fft: int, window):
+        return torch.istft(x, n_fft=n_fft, window=window)
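
These modules only produce useful fixtures when exported by an old runtime, so the checked-in .ptl files above call the v10 operator schema. A sketch of what the export step amounts to (generate_models.py in the next diff automates this; the import path is shortened here for illustration):

import torch
from fixtures_src import TestVersionedStftV10  # assumed import path

# Run under a torch build at operator version 10 to reproduce the fixture:
m = torch.jit.script(TestVersionedStftV10())
m._save_for_lite_interpreter("test/jit/fixtures/test_versioned_stft_v10.ptl")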
2 changes: 2 additions & 0 deletions test/jit/fixtures_srcs/generate_models.py
@@ -96,6 +96,8 @@ def div_Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
     TestVersionedLogspaceOutV8(): "aten::logspace.out",
     TestVersionedGeluV9(): "aten::gelu",
     TestVersionedGeluOutV9(): "aten::gelu.out",
+    TestVersionedStftV10(): "aten::stft",
+    TestVersionedIstftV10(): "aten::istft",
 }

 """
36 changes: 36 additions & 0 deletions test/jit/test_save_load_for_op_version.py
@@ -540,3 +540,39 @@ def forward(self, a: Union[int, float, complex], b: Union[int, float, complex],
         self.assertTrue(output.size(dim=0) == 100)
         # "Upgraded" model should match the new version output
         self.assertEqual(output, output_current)
+
+    def test_versioned_stft(self):
+        model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_stft_v10.ptl"
+        loaded_model = torch.jit.load(model_path)
+        buffer = io.BytesIO(loaded_model._save_to_buffer_for_lite_interpreter())
+        buffer.seek(0)
+        v10_mobile_module = _load_for_lite_interpreter(buffer)
+
+        for in_dtype, window_dtype in product(
+                [torch.float32, torch.complex64], repeat=2):
+            input = torch.rand((100,), dtype=in_dtype)
+            window = torch.rand((10,), dtype=window_dtype)
+            output = v10_mobile_module(input, 10, window)
+            output_current = torch.stft(input, n_fft=10, window=window, return_complex=True)
+
+            if input.is_complex() or window.is_complex():
+                self.assertEqual(output, output_current)
+            else:
+                self.assertEqual(torch.view_as_complex(output), output_current)
+
+    def test_versioned_istft(self):
+        model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_istft_v10.ptl"
+        loaded_model = torch.jit.load(model_path)
+        buffer = io.BytesIO(loaded_model._save_to_buffer_for_lite_interpreter())
+        buffer.seek(0)
+        v10_mobile_module = _load_for_lite_interpreter(buffer)
+
+        for real_input in [True, False]:
+            input = torch.rand((10, 10,), dtype=torch.complex64)
+            old_input = torch.view_as_real(input) if real_input else input
+
+            window = torch.rand((10,))
+            output = v10_mobile_module(old_input, 10, window)
+            output_current = torch.istft(input, n_fft=10, window=window)
+
+            self.assertEqual(output, output_current)
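
Taken together, the two tests pin down the upgrader contract: a v10 model must keep its old output format even on a runtime whose stft always returns complex. Roughly, the rewrite the upgrader has to perform looks like this (inferred from the assertions above; not the literal upgrader TorchScript):

import torch

def stft_v10_compat(x, n_fft: int, window):
    out = torch.stft(x, n_fft=n_fft, window=window, return_complex=True)
    if x.is_complex() or window.is_complex():
        return out                    # v10 already returned complex here
    return torch.view_as_real(out)    # v10 returned the real view here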
52 changes: 29 additions & 23 deletions test/test_spectral_ops.py
@@ -42,6 +42,14 @@
     not has_scipy_fft or LooseVersion(scipy.__version__) >= '1.6.0')
     else (None, "ortho"))

+def _complex_from_float_dtype(real_dtype):
+    return {
+        torch.float16: torch.complex32,
+        torch.float32: torch.complex64,
+        torch.float64: torch.complex128
+    }[real_dtype]
+
+
 def _complex_stft(x, *args, **kwargs):
     # Transform real and imaginary components separably
@@ -738,11 +746,7 @@ def test_fftshift_frequencies(self, device, dtype):

     # Legacy fft tests
     def _test_fft_ifft_rfft_irfft(self, device, dtype):
-        complex_dtype = {
-            torch.float16: torch.complex32,
-            torch.float32: torch.complex64,
-            torch.float64: torch.complex128
-        }[dtype]
+        complex_dtype = _complex_from_float_dtype(dtype)

         def _test_complex(sizes, signal_ndim, prepro_fn=lambda x: x):
             x = prepro_fn(torch.randn(*sizes, dtype=complex_dtype, device=device))
@@ -1136,9 +1140,8 @@ def test_complex_stft_onesided(self, device):
     @skipCPUIfNoFFT
     def test_stft_requires_complex(self, device):
         x = torch.rand(100)
-        y = x.stft(10, pad_mode='constant')
-        # with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
-        #     y = x.stft(10, pad_mode='constant')
+        with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
+            y = x.stft(10, pad_mode='constant')

     @skipCPUIfNoFFT
     def test_fft_input_modification(self, device):
@@ -1346,20 +1349,22 @@ def test_istft_throws(self, device):
     @skipCPUIfNoFFT
     @dtypes(torch.double)
     def test_istft_of_sine(self, device, dtype):
+        complex_dtype = _complex_from_float_dtype(dtype)
+
         def _test(amplitude, L, n):
             # stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L
             x = torch.arange(2 * L + 1, device=device, dtype=dtype)
             original = amplitude * torch.sin(2 * math.pi / L * x * n)
             # stft = torch.stft(original, L, hop_length=L, win_length=L,
             #                   window=torch.ones(L), center=False, normalized=False)
-            stft = torch.zeros((L // 2 + 1, 2, 2), device=device, dtype=dtype)
+            stft = torch.zeros((L // 2 + 1, 2), device=device, dtype=complex_dtype)
             stft_largest_val = (amplitude * L) / 2.0
             if n < stft.size(0):
-                stft[n, :, 1] = -stft_largest_val
+                stft[n].imag = torch.tensor(-stft_largest_val, dtype=dtype)

             if 0 <= L - n < stft.size(0):
                 # symmetric about L // 2
-                stft[L - n, :, 1] = stft_largest_val
+                stft[L - n].imag = torch.tensor(stft_largest_val, dtype=dtype)

             inverse = torch.istft(
                 stft, L, hop_length=L, win_length=L,
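
The hard-coded bin values follow from the length-L DFT of a pure sine; a quick sketch of the arithmetic behind stft_largest_val:

$$X[m] = \sum_{k=0}^{L-1} A \sin\!\Big(\frac{2\pi n k}{L}\Big)\, e^{-2\pi i m k/L} = -\,i\,\frac{AL}{2}\,\delta_{m,n} + i\,\frac{AL}{2}\,\delta_{m,\,L-n},$$

so bin n carries imaginary part -AL/2 and bin L - n carries +AL/2, exactly the values the rewritten test writes into stft[n].imag and stft[L - n].imag.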
@@ -1381,11 +1386,12 @@ def _test(amplitude, L, n):
     @dtypes(torch.double)
     def test_istft_linearity(self, device, dtype):
         num_trials = 100
+        complex_dtype = _complex_from_float_dtype(dtype)

         def _test(data_size, kwargs):
             for i in range(num_trials):
-                tensor1 = torch.randn(data_size, device=device, dtype=dtype)
-                tensor2 = torch.randn(data_size, device=device, dtype=dtype)
+                tensor1 = torch.randn(data_size, device=device, dtype=complex_dtype)
+                tensor2 = torch.randn(data_size, device=device, dtype=complex_dtype)
                 a, b = torch.rand(2, dtype=dtype, device=device)
                 # Also compare method vs. functional call signature
                 istft1 = tensor1.istft(**kwargs)
@@ -1396,7 +1402,7 @@ def _test(data_size, kwargs):
         patterns = [
             # hann_window, centered, normalized, onesided
             (
-                (2, 7, 7, 2),
+                (2, 7, 7),
                 {
                     'n_fft': 12,
                     'window': torch.hann_window(12, device=device, dtype=dtype),
@@ -1407,7 +1413,7 @@
             ),
             # hann_window, centered, not normalized, not onesided
             (
-                (2, 12, 7, 2),
+                (2, 12, 7),
                 {
                     'n_fft': 12,
                     'window': torch.hann_window(12, device=device, dtype=dtype),
@@ -1418,7 +1424,7 @@
             ),
             # hamming_window, centered, normalized, not onesided
             (
-                (2, 12, 7, 2),
+                (2, 12, 7),
                 {
                     'n_fft': 12,
                     'window': torch.hamming_window(12, device=device, dtype=dtype),
@@ -1429,7 +1435,7 @@
             ),
             # hamming_window, not centered, not normalized, onesided
             (
-                (2, 7, 3, 2),
+                (2, 7, 3),
                 {
                     'n_fft': 12,
                     'window': torch.hamming_window(12, device=device, dtype=dtype),
@@ -1446,13 +1452,13 @@
     @skipCPUIfNoFFT
     def test_batch_istft(self, device):
         original = torch.tensor([
-            [[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]],
-            [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]],
-            [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]]
-        ], device=device)
+            [4., 4., 4., 4., 4.],
+            [0., 0., 0., 0., 0.],
+            [0., 0., 0., 0., 0.]
+        ], device=device, dtype=torch.complex64)

-        single = original.repeat(1, 1, 1, 1)
-        multi = original.repeat(4, 1, 1, 1)
+        single = original.repeat(1, 1, 1)
+        multi = original.repeat(4, 1, 1)

         i_original = torch.istft(original, n_fft=4, length=4)
         i_single = torch.istft(single, n_fft=4, length=4)