pytorch/pytorch · Commit e129f3f
stft: move towards always returning complex
For `stft`, this turns every case where `return_complex` would default to `False` into an error, and adds a warning when `return_complex=False` is passed explicitly. For `istft`, this raises an error if the input is not a complex tensor.

ghstack-source-id: eb7feb6
Pull Request resolved: #72882
1 parent f43165a commit e129f3f
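A minimal sketch of the user-facing change (assuming a PyTorch build that includes this commit; the error and warning texts are quoted from the diff below):

```python
import torch

x = torch.rand(100)  # real-valued signal

# Previously this warned and returned a real tensor of shape (..., 2);
# with this commit it raises "stft requires the return_complex parameter ...":
# torch.stft(x, n_fft=10)

spec = torch.stft(x, n_fft=10, return_complex=True)  # complex output, no warning
old_format = torch.view_as_real(spec)                # recover the old (..., 2) real layout

# Passing return_complex=False explicitly still works, but now emits the
# "stft with return_complex=False is deprecated" warning.
```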

File tree: 12 files changed, +236 −33 lines

aten/src/ATen/native/SpectralOps.cpp

Lines changed: 13 additions & 18 deletions

```diff
@@ -792,20 +792,17 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop
   const bool return_complex = return_complexOpt.value_or(
       self.is_complex() || (window.defined() && window.is_complex()));
   if (!return_complex) {
-    if (!return_complexOpt.has_value()) {
-      TORCH_WARN_ONCE(
-        "stft will soon require the return_complex parameter be given for real inputs, "
-        "and will further require that return_complex=True in a future PyTorch release."
-      );
-    }
+    TORCH_CHECK(return_complexOpt.has_value(),
+        "stft requires the return_complex parameter be given for real inputs, "
+        "and will further require that return_complex=True in a future PyTorch release.");


-    // TORCH_WARN_ONCE(
-    //   "stft with return_complex=False is deprecated. In a future pytorch "
-    //   "release, stft will return complex tensors for all inputs, and "
-    //   "return_complex=False will raise an error.\n"
-    //   "Note: you can still call torch.view_as_real on the complex output to "
-    //   "recover the old return format.");
+    TORCH_WARN_ONCE(
+      "stft with return_complex=False is deprecated. In a future pytorch "
+      "release, stft will return complex tensors for all inputs, and "
+      "return_complex=False will raise an error.\n"
+      "Note: you can still call torch.view_as_real on the complex output to "
+      "recover the old return format.");
   }

   if (!at::isFloatingType(self.scalar_type()) && !at::isComplexType(self.scalar_type())) {
@@ -945,12 +942,10 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> ho
   const auto hop_length = hop_lengthOpt.value_or(n_fft >> 2);
   const auto win_length = win_lengthOpt.value_or(n_fft);

-  if (!self.is_complex()) {
-    TORCH_WARN_ONCE(
-      "istft will require a complex-valued input tensor in a future PyTorch release. "
-      "Matching the output from stft with return_complex=True. ");
-  }
-  Tensor input = self.is_complex() ? self.is_conj() ? at::view_as_real(self.resolve_conj()) : at::view_as_real(self) : self;
+  TORCH_CHECK(self.is_complex(),
+      "istft requires a complex-valued input tensor matching the "
+      "output from stft with return_complex=True.");
+  Tensor input = at::view_as_real(self.resolve_conj());
   const auto input_dim = input.dim();
   const auto n_frames = input.size(-2);
   const auto fft_size = input.size(-3);
```
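On the `istft` side, a real-view input of shape `(..., 2)` is now rejected outright. A sketch of the migration (the conversion via `torch.view_as_complex` mirrors what the TorchScript upgrader later in this commit does for old serialized models):

```python
import torch

spec = torch.stft(torch.rand(100), n_fft=10, return_complex=True)
torch.istft(spec, n_fft=10)  # OK: complex input

old_style = torch.view_as_real(spec)  # the deprecated (..., 2) real layout
# torch.istft(old_style, n_fft=10)    # now raises "istft requires a complex-valued input tensor"

# Explicit, lossless conversion back to the supported format:
torch.istft(torch.view_as_complex(old_style.contiguous()), n_fft=10)
```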

caffe2/serialize/versions.h

Lines changed: 6 additions & 2 deletions

```diff
@@ -12,7 +12,7 @@ namespace serialize {
 constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;

 #if ENABLE_UPGRADERS
-constexpr uint64_t kMaxSupportedFileFormatVersion = 0xAL;
+constexpr uint64_t kMaxSupportedFileFormatVersion = 11UL;
 #else
 constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
 #endif
@@ -83,7 +83,11 @@ constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
 // Bump the version number to 10 to update aten::gelu and
 // and aten::gelu.out to support the new approximate kwarg.
 // (see: https://github.com/pytorch/pytorch/pull/61439)
-constexpr uint64_t kProducedFileFormatVersion = 0xAL;
+// 4) [02/15/2022]
+// Bump the version number to 11 to update aten::stft and
+// and aten::istft to deprecate real-dtype complex representation
+// (see: ###)
+constexpr uint64_t kProducedFileFormatVersion = 11UL;
 #else
 constexpr uint64_t kProducedFileFormatVersion = 0x3L;
 #endif
```
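For context, a hedged way to check which file-format version a given TorchScript archive carries (the record name and its zip prefix, e.g. `archive/`, are serialization details assumed here, not part of this diff):

```python
import zipfile

def file_format_version(path: str) -> int:
    # TorchScript .pt/.ptl files are zip archives; the producer's file-format
    # version is stored as a plain integer in a record named "version".
    with zipfile.ZipFile(path) as zf:
        for name in zf.namelist():
            if name.split("/")[-1] == "version":
                return int(zf.read(name).decode().strip())
    raise RuntimeError(f"no version record found in {path}")

# Models saved after this commit should report 11; the v10 fixtures below
# report an older version and are routed through the registered upgraders.
```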
Two binary test fixtures changed (5.04 KB and 22.1 KB); binary files not shown.

test/jit/fixtures_srcs/fixtures_src.py

Lines changed: 14 additions & 0 deletions

```diff
@@ -57,3 +57,17 @@ def __init__(self):
     def forward(self, x):
         out = torch.zeros_like(x)
         return torch._C._nn.gelu(x, out=out)
+
+class TestVersionedStftV10(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, n_fft: int, window):
+        return torch.stft(x, n_fft=n_fft, window=window)
+
+class TestVersionedIstftV10(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, n_fft: int, window):
+        return torch.istft(x, n_fft=n_fft, window=window)
```
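The fixtures referenced by the tests are produced from these modules roughly as follows (a sketch; the real flow lives in `generate_models.py`, the import and output paths here are illustrative, and the fixture must be generated with a PyTorch build from *before* the version bump so it records the old `aten::stft`/`aten::istft` schemas):

```python
import torch
from fixtures_src import TestVersionedStftV10  # assumes running inside test/jit/fixtures_srcs

scripted = torch.jit.script(TestVersionedStftV10())
# _save_for_lite_interpreter writes a lite-interpreter .ptl model
scripted._save_for_lite_interpreter("../fixtures/test_versioned_stft_v10.ptl")
```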

test/jit/fixtures_srcs/generate_models.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -96,6 +96,8 @@ def div_Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
     TestVersionedLogspaceOutV8(): "aten::logspace.out",
     TestVersionedGeluV9(): "aten::gelu",
     TestVersionedGeluOutV9(): "aten::gelu.out",
+    TestVersionedStftV10(): "aten::stft",
+    TestVersionedIstftV10(): "aten::istft",
 }

 """
```

test/jit/test_save_load_for_op_version.py

Lines changed: 36 additions & 0 deletions

```diff
@@ -540,3 +540,39 @@ def forward(self, a: Union[int, float, complex], b: Union[int, float, complex],
         self.assertTrue(output.size(dim=0) == 100)
         # "Upgraded" model should match the new version output
         self.assertEqual(output, output_current)
+
+    def test_versioned_stft(self):
+        model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_stft_v10.ptl"
+        loaded_model = torch.jit.load(model_path)
+        buffer = io.BytesIO(loaded_model._save_to_buffer_for_lite_interpreter())
+        buffer.seek(0)
+        v10_mobile_module = _load_for_lite_interpreter(buffer)
+
+        for in_dtype, window_dtype in product(
+                [torch.float32, torch.complex64], repeat=2):
+            input = torch.rand((100,), dtype=in_dtype)
+            window = torch.rand((10,), dtype=window_dtype)
+            output = v10_mobile_module(input, 10, window)
+            output_current = torch.stft(input, n_fft=10, window=window, return_complex=True)
+
+            if input.is_complex() or window.is_complex():
+                self.assertEqual(output, output_current)
+            else:
+                self.assertEqual(torch.view_as_complex(output), output_current)
+
+    def test_versioned_istft(self):
+        model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_istft_v10.ptl"
+        loaded_model = torch.jit.load(model_path)
+        buffer = io.BytesIO(loaded_model._save_to_buffer_for_lite_interpreter())
+        buffer.seek(0)
+        v10_mobile_module = _load_for_lite_interpreter(buffer)
+
+        for real_input in [True, False]:
+            input = torch.rand((10, 10,), dtype=torch.complex64)
+            old_input = torch.view_as_real(input) if real_input else input
+
+            window = torch.rand((10,))
+            output = v10_mobile_module(old_input, 10, window)
+            output_current = torch.istft(input, n_fft=10, window=window)
+
+            self.assertEqual(output, output_current)
```
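These comparisons are lossless because `view_as_real`/`view_as_complex` are zero-copy views of the same storage, so round-tripping between the old and new formats is exact:

```python
import torch

z = torch.rand(10, dtype=torch.complex64)
r = torch.view_as_real(z)  # shape (10, 2), shares storage with z
assert torch.equal(torch.view_as_complex(r), z)
assert r.data_ptr() == z.data_ptr()  # a view, not a copy
```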

test/test_spectral_ops.py

Lines changed: 2 additions & 3 deletions

```diff
@@ -1117,9 +1117,8 @@ def test_complex_stft_onesided(self, device):
     @skipCPUIfNoFFT
     def test_stft_requires_complex(self, device):
         x = torch.rand(100)
-        y = x.stft(10, pad_mode='constant')
-        # with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
-        #     y = x.stft(10, pad_mode='constant')
+        with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
+            y = x.stft(10, pad_mode='constant')

     @skipCPUIfNoFFT
     def test_fft_input_modification(self, device):
```

torch/csrc/jit/mobile/upgrader_mobile.cpp

Lines changed: 120 additions & 4 deletions

```diff
@@ -51,21 +51,29 @@ getOperatorVersionMapForMobile() {
      std::vector<Upgrader>({
          Upgrader({0, 9, "gelu_out_0_9", 6})
      })},
+    {std::string("aten::istft"),
+     std::vector<Upgrader>({
+         Upgrader({0, 10, "istft_0_10", 7})
+     })},
     {std::string("aten::linspace"),
      std::vector<Upgrader>({
-         Upgrader({0, 7, "linspace_0_7", 7})
+         Upgrader({0, 7, "linspace_0_7", 8})
      })},
     {std::string("aten::linspace.out"),
      std::vector<Upgrader>({
-         Upgrader({0, 7, "linspace_out_0_7", 8})
+         Upgrader({0, 7, "linspace_out_0_7", 9})
      })},
     {std::string("aten::logspace"),
      std::vector<Upgrader>({
-         Upgrader({0, 8, "logspace_0_8", 9})
+         Upgrader({0, 8, "logspace_0_8", 10})
      })},
     {std::string("aten::logspace.out"),
      std::vector<Upgrader>({
-         Upgrader({0, 8, "logspace_out_0_8", 10})
+         Upgrader({0, 8, "logspace_out_0_8", 11})
+     })},
+    {std::string("aten::stft"),
+     std::vector<Upgrader>({
+         Upgrader({0, 10, "stft_0_10", 12})
      })},
   });
   return operatorVersionMapForMobile;
@@ -339,6 +347,48 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
           OperatorString({"aten::gelu", "out", 2}),
       }), // operators list
   }),
+  ByteCodeFunctionWithOperator({
+      mobile::Function::registerFunc(
+          "istft_0_10",
+          std::vector<Instruction>({
+              Instruction{OpCode::STOREN, 1, 10},
+              Instruction{OpCode::LOAD, 1, 0},
+              Instruction{OpCode::OP, 0, 0},
+              Instruction{OpCode::__NOT__, 0, 0},
+              Instruction{OpCode::JF, 5, 0},
+              Instruction{OpCode::LOAD, 1, 0},
+              Instruction{OpCode::OP, 1, 0},
+              Instruction{OpCode::OP, 2, 0},
+              Instruction{OpCode::JMP, 2, 0},
+              Instruction{OpCode::LOAD, 1, 0},
+              Instruction{OpCode::STORE, 11, 0},
+              Instruction{OpCode::DROPR, 1, 0},
+              Instruction{OpCode::MOVE, 11, 0},
+              Instruction{OpCode::MOVE, 2, 0},
+              Instruction{OpCode::MOVE, 3, 0},
+              Instruction{OpCode::MOVE, 4, 0},
+              Instruction{OpCode::MOVE, 5, 0},
+              Instruction{OpCode::MOVE, 6, 0},
+              Instruction{OpCode::MOVE, 7, 0},
+              Instruction{OpCode::MOVE, 8, 0},
+              Instruction{OpCode::MOVE, 9, 0},
+              Instruction{OpCode::MOVE, 10, 0},
+              Instruction{OpCode::OP, 3, 0},
+              Instruction{OpCode::RET, 0, 0},
+          }), // instructions list,
+          std::vector<c10::IValue>({
+              c10::IValue(0),
+          }), // constants list,
+          std::vector<c10::TypePtr>(), // types list,
+          11
+      ),
+      std::vector<OperatorString>({
+          OperatorString({"aten::is_complex", "", 1}),
+          OperatorString({"aten::contiguous", "", 1}),
+          OperatorString({"aten::view_as_complex", "", 1}),
+          OperatorString({"aten::istft", "", 10}),
+      }), // operators list
+  }),
   ByteCodeFunctionWithOperator({
       mobile::Function::registerFunc(
           "linspace_0_7",
@@ -527,6 +577,72 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
           OperatorString({"prim::unchecked_cast", "", 1}),
       }), // operators list
   }),
+  ByteCodeFunctionWithOperator({
+      mobile::Function::registerFunc(
+          "stft_0_10",
+          std::vector<Instruction>({
+              Instruction{OpCode::STOREN, 1, 8},
+              Instruction{OpCode::LOAD, 8, 0},
+              Instruction{OpCode::LOADC, 2, 0},
+              Instruction{OpCode::__IS__, 0, 0},
+              Instruction{OpCode::JF, 27, 0},
+              Instruction{OpCode::LOAD, 1, 0},
+              Instruction{OpCode::OP, 0, 0},
+              Instruction{OpCode::JF, 4, 0},
+              Instruction{OpCode::LOAD, 5, 0},
+              Instruction{OpCode::LOADC, 1, 0},
+              Instruction{OpCode::JMP, 17, 0},
+              Instruction{OpCode::LOAD, 5, 0},
+              Instruction{OpCode::LOADC, 2, 0},
+              Instruction{OpCode::__ISNOT__, 0, 0},
+              Instruction{OpCode::JF, 8, 0},
+              Instruction{OpCode::LOAD, 5, 0},
+              Instruction{OpCode::OP, 1, 0},
+              Instruction{OpCode::STORE, 9, 0},
+              Instruction{OpCode::LOAD, 9, 0},
+              Instruction{OpCode::MOVE, 9, 0},
+              Instruction{OpCode::OP, 0, 0},
+              Instruction{OpCode::JMP, 3, 0},
+              Instruction{OpCode::LOAD, 5, 0},
+              Instruction{OpCode::LOADC, 0, 0},
+              Instruction{OpCode::STOREN, 10, 2},
+              Instruction{OpCode::MOVE, 10, 0},
+              Instruction{OpCode::MOVE, 11, 0},
+              Instruction{OpCode::STOREN, 12, 2},
+              Instruction{OpCode::MOVE, 12, 0},
+              Instruction{OpCode::MOVE, 13, 0},
+              Instruction{OpCode::JMP, 4, 0},
+              Instruction{OpCode::LOAD, 5, 0},
+              Instruction{OpCode::LOAD, 8, 0},
+              Instruction{OpCode::OP, 1, 0},
+              Instruction{OpCode::STOREN, 14, 2},
+              Instruction{OpCode::DROPR, 5, 0},
+              Instruction{OpCode::DROPR, 8, 0},
+              Instruction{OpCode::MOVE, 1, 0},
+              Instruction{OpCode::MOVE, 2, 0},
+              Instruction{OpCode::MOVE, 3, 0},
+              Instruction{OpCode::MOVE, 4, 0},
+              Instruction{OpCode::MOVE, 14, 0},
+              Instruction{OpCode::MOVE, 6, 0},
+              Instruction{OpCode::MOVE, 7, 0},
+              Instruction{OpCode::MOVE, 15, 0},
+              Instruction{OpCode::OP, 2, 0},
+              Instruction{OpCode::RET, 0, 0},
+          }), // instructions list,
+          std::vector<c10::IValue>({
+              c10::IValue(false),
+              c10::IValue(true),
+              c10::IValue(),
+          }), // constants list,
+          std::vector<c10::TypePtr>(), // types list,
+          15
+      ),
+      std::vector<OperatorString>({
+          OperatorString({"aten::is_complex", "", 1}),
+          OperatorString({"prim::unchecked_cast", "", 1}),
+          OperatorString({"aten::stft", "", 8}),
+      }), // operators list
+  }),
   });
   for (const auto& upgrader_function : upgrader_function_list) {
     for (const auto& op : upgrader_function.operators) {
```

torch/csrc/jit/operator_upgraders/upgraders_entry.cpp

Lines changed: 21 additions & 0 deletions

```diff
@@ -15,6 +15,27 @@ namespace torch {
 namespace jit {

 static std::unordered_map<std::string, std::string> kUpgradersEntryMap({
+    {"istft_0_10", R"SCRIPT(
+def istft_0_10(self: Tensor, n_fft: int, hop_length: Optional[int] = None,
+               win_length: Optional[int] = None, window: Optional[Tensor] = None,
+               center: bool = True, normalized: bool = False,
+               onesided: Optional[bool] = None, length: Optional[int] = None,
+               return_complex: bool = False) -> Tensor:
+    if not self.is_complex():
+        self = torch.view_as_complex(self.contiguous())
+    return torch.istft(self, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
+                       window=window, center=center, normalized=normalized, onesided=onesided,
+                       length=length, return_complex=return_complex)
+)SCRIPT"},
+    {"stft_0_10", R"SCRIPT(
+def stft_0_10(self: Tensor, n_fft: int, hop_length: Optional[int]=None, win_length: Optional[int]=None,
+              window: Optional[Tensor]=None, normalized: bool=False, onesided: Optional[bool]=None,
+              return_complex: Optional[bool]=None) -> Tensor:
+    if return_complex is None:
+        return_complex = self.is_complex() or (window is not None and window.is_complex())
+    return torch.stft(self, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window,
+                      normalized=normalized, onesided=onesided, return_complex=return_complex)
+)SCRIPT"},
     {"logspace_0_8", R"SCRIPT(
 def logspace_0_8(start: Union[int, float, complex], end: Union[int, float, complex], steps: Optional[int], base: float, *, dtype: Optional[int], layout: Optional[int],
                  device: Optional[Device], pin_memory: Optional[bool]):
```
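The `stft_0_10` upgrader reproduces the old default for models serialized before the bump. In plain Python, the resolution rule it encodes looks like this (`resolve_return_complex` is an illustrative helper, not an API in this commit):

```python
import torch

def resolve_return_complex(x, window=None, return_complex=None):
    # Old behavior: default to complex output iff the input or window is complex.
    if return_complex is None:
        return_complex = x.is_complex() or (window is not None and window.is_complex())
    return return_complex

assert resolve_return_complex(torch.rand(8, dtype=torch.complex64)) is True
assert resolve_return_complex(torch.rand(8)) is False  # the case that now errors in eager mode
```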

torch/csrc/jit/operator_upgraders/version_map.cpp

Lines changed: 10 additions & 1 deletion

```diff
@@ -16,7 +16,16 @@ static bool isVersionMapSorted = false;
 // Note for developers: The list of upgraders should be SORTED
 // by the version number where the upgrader is registered.
 static std::unordered_map<std::string, std::vector<UpgraderEntry>> operatorVersionMap(
-    {{"aten::logspace",
+    {
+     {"aten::istft",
+      {{11,
+        "istft_0_10",
+        "aten::istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? length=None, bool return_complex=False) -> Tensor"}}},
+     {"aten::stft",
+      {{11,
+        "stft_0_10",
+        "aten::stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"}}},
+     {"aten::logspace",
       {{9,
        "logspace_0_8",
        "aten::logspace(Scalar start, Scalar end, int? steps=None, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},
```
