8000 Fix workarea compute in lapackSyevd (#146456) · pytorch/pytorch@4d626c2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4d626c2

Browse files
wdvr authored and pytorchmergebot committed
Fix workarea compute in lapackSyevd (#146456)
Work-query APIs return floating point values that could lose precision when converted back to int. Solve this by using `nextafter` and `ceil`. Add a regression test. Fixes #145801. Pull Request resolved: #146456 Approved by: https://github.com/malfet
1 parent 8f07306 commit 4d626c2

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

aten/src/ATen/native/BatchLinearAlgebraKernel.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,9 @@ void apply_lapack_eigh(const Tensor& values, const Tensor& vectors, const Tensor
256256
lapackSyevd<scalar_t, value_t>(jobz, uplo, n, vectors_data, lda, values_data,
257257
&lwork_query, lwork, &rwork_query, lrwork, &iwork_query, liwork, infos_data);
258258

259-
lwork = std::max<int>(1, real_impl<scalar_t, value_t>(lwork_query));
259+
value_t next_after_lw = std::nextafter(real_impl<scalar_t, value_t>(lwork_query), std::numeric_limits<value_t>::infinity());
260+
lwork = std::max<int>(1, std::ceil(next_after_lw));
261+
260262
Tensor work = at::empty({lwork}, vectors.options());
261263
auto work_data = work.mutable_data_ptr<scalar_t>();
262264

@@ -267,7 +269,8 @@ void apply_lapack_eigh(const Tensor& values, const Tensor& vectors, const Tensor
267269
Tensor rwork;
268270
value_t* rwork_data = nullptr;
269271
if (vectors.is_complex()) {
270-
lrwork = std::max<int>(1, rwork_query);
272+
value_t next_after_rwork_query = std::nextafter(rwork_query, std::numeric_limits<value_t>::infinity());
273+
lrwork = std::max<int>(1, std::ceil(next_after_rwork_query));
271274
rwork = at::empty({lrwork}, values.options());
272275
rwork_data = rwork.mutable_data_ptr<value_t>();
273276
}

test/test_linalg.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,6 +1119,15 @@ def test_eigvalsh_errors_and_warnings(self, device, dtype):
11191119
with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
11201120
torch.linalg.eigvalsh(t, out=out)
11211121

1122+
@onlyCPU
1123+
@skipCPUIfNoLapack
1124+
@dtypes(*floating_and_complex_types())
1125+
def test_eigh_lwork_lapack(self, device, dtype):
1126+
# test that the calculated lwork does not cause a crash, see https://github.com/pytorch/pytorch/issues/145801
1127+
t = torch.rand(3000, 3000, device=device, dtype=dtype)
1128+
y = torch.linalg.eigh(t)
1129+
self.assertEqual(y.eigenvalues.shape, (3000,))
1130+
11221131
@dtypes(*floating_and_complex_types())
11231132
def test_kron(self, device, dtype):
11241133

0 commit comments

Comments (0)
0