Updated linalg.lstsq with NumPy compatible kwarg rcond · IvanYashchuk/pytorch@67e45df · GitHub

Commit 67e45df

Updated linalg.lstsq with NumPy compatible kwarg rcond
Renamed "cond" -> "rcond" to be NumPy compatible. The default value for rcond was changed to match non-legacy NumPy behavior. ghstack-source-id: 1208544 Pull Request resolved: pytorch#54723
1 parent c3099bc commit 67e45df

File tree

7 files changed: +57 -49 lines changed
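In effect, the commit changes which cutoff torch.linalg.lstsq uses when rcond is omitted. A minimal before/after sketch of the default (not part of the commit; torch.finfo is used here to read the machine epsilon):

import torch

a = torch.randn(5, 3)  # m = 5, n = 3
m, n = a.shape[-2], a.shape[-1]
eps = torch.finfo(a.dtype).eps

old_default = eps              # previous behavior: bare machine epsilon
new_default = eps * max(m, n)  # NumPy-compatible: epsilon scaled by max(m, n)

print(old_default, new_default)  # the new cutoff is max(m, n) times larger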

aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 8 additions & 6 deletions
@@ -2860,7 +2860,10 @@ struct LapackLstsqHelper {
     }
     return *this;
   }
-  self_type& set_rcond(double cond) { this->rcond = static_cast<value_t>(cond); return *this; }
+  self_type& set_rcond(double rcond) {
+    this->rcond = static_cast<value_t>(rcond);
+    return *this;
+  }
   self_type& set_rank(Tensor& rank) {
     // only `?gels` is not rank-revealing
     if (LapackLstsqDriverType::Gels != driver_type) {
@@ -2977,7 +2980,7 @@ struct LapackLstsqDriverTypeHash {
 #endif

 Tensor& _lstsq_helper_cpu(
-    Tensor& b, Tensor& rank, Tensor& singular_values, Tensor& infos, const Tensor& a, double cond, std::string driver_name) {
+    Tensor& b, Tensor& rank, Tensor& singular_values, Tensor& infos, const Tensor& a, double rcond, std::string driver_name) {
 #ifndef USE_LAPACK
   TORCH_CHECK(false, "torch.linalg.lstsq: LAPACK library not found in compilation");
 #else
@@ -3016,7 +3019,7 @@ Tensor& _lstsq_helper_cpu(
     .set_b(b)
     .set_ldb(std::max<int64_t>(1, std::max(m, n)))
     .set_jpvt()
-    .set_rcond(cond)
+    .set_rcond(rcond)
     .set_rank(rank)
     .set_s(singular_values)
     .set_infos(infos)
@@ -3308,10 +3311,9 @@ std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> linalg_lstsq_out(
   std::string driver_name = get_default_lstsq_driver(driver, input);

   // set default rcond value
-  // TODO: Change this to match non-legacy NumPy behaviour
-  double rcond_value = rcond.has_value() && (rcond.value() > 0)
+  double rcond_value = rcond.has_value()
     ? rcond.value()
-    : _get_epsilon(c10::toValueType(input.scalar_type()));
+    : _get_epsilon(c10::toValueType(input.scalar_type())) * std::max<int64_t>(input.size(-2), input.size(-1));

   auto infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, input.options().dtype(kInt));

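Besides the larger default, dropping the `rcond.value() > 0` guard means a user-supplied negative rcond is now forwarded to LAPACK instead of being silently replaced by epsilon; LAPACK interprets a negative rcond as "use machine precision". A hypothetical Python mirror of the selection logic above (resolve_rcond is an invented name, not an API):

import torch

def resolve_rcond(rcond, dtype, m, n):
    # hypothetical mirror of linalg_lstsq_out's default selection (invented helper)
    if rcond is not None:
        # any explicit value, including negative ones, is passed through;
        # LAPACK treats a negative rcond as machine precision
        return rcond
    return torch.finfo(dtype).eps * max(m, n)

print(resolve_rcond(None, torch.float64, 5, 3))  # eps * 5
print(resolve_rcond(-1, torch.float64, 5, 3))    # -1, handed to LAPACK as-is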
aten/src/ATen/native/cuda/BatchLinearAlgebra.cu

Lines changed: 1 addition & 1 deletion
@@ -2735,7 +2735,7 @@ Tensor _lu_solve_helper_cuda(const Tensor& self, const Tensor& LU_data, const Te
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lstsq ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 Tensor& _lstsq_helper_cuda(
-    Tensor& b, Tensor& rank, Tensor& singular_values, Tensor& infos, const Tensor& a, double cond, std::string driver_name) {
+    Tensor& b, Tensor& rank, Tensor& singular_values, Tensor& infos, const Tensor& a, double rcond, std::string driver_name) {
 #ifndef USE_MAGMA
   TORCH_CHECK(false, "torch.linalg.lstsq: MAGMA library not found in "
               "compilation. Please rebuild with MAGMA.");

aten/src/ATen/native/native_functions.yaml

Lines changed: 3 additions & 3 deletions
@@ -8657,19 +8657,19 @@
 - func: det(Tensor self) -> Tensor
   variants: function, method

-- func: linalg_lstsq(Tensor self, Tensor b, float? cond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values)
+- func: linalg_lstsq(Tensor self, Tensor b, float? rcond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values)
   python_module: linalg
   variants: function
   dispatch:
     CompositeExplicitAutograd: linalg_lstsq

-- func: linalg_lstsq.out(Tensor self, Tensor b, float? cond=None, *, str? driver=None, Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) -> (Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values)
+- func: linalg_lstsq.out(Tensor self, Tensor b, float? rcond=None, *, str? driver=None, Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) -> (Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values)
   python_module: linalg
   variants: function
   dispatch:
     CPU, CUDA: linalg_lstsq_out

-- func: _lstsq_helper_(Tensor(a!) self, Tensor(b!) rank, Tensor(c!) singular_values, Tensor(d!) infos, Tensor a, float cond, str driver_name) -> Tensor(a!)
+- func: _lstsq_helper_(Tensor(a!) self, Tensor(b!) rank, Tensor(c!) singular_values, Tensor(d!) infos, Tensor a, float rcond, str driver_name) -> Tensor(a!)
   variants: function
   dispatch:
     CPU: _lstsq_helper_cpu

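The schema above declares a named 4-tuple result (solution, residuals, rank, singular_values). A quick smoke test of the renamed keyword (assumes a CPU build with LAPACK so the rank-revealing gelsd driver is available):

import torch

a = torch.randn(4, 3, dtype=torch.float64)
b = torch.randn(4, 2, dtype=torch.float64)

# the keyword is now `rcond`, matching numpy.linalg.lstsq
solution, residuals, rank, singular_values = torch.linalg.lstsq(a, b, rcond=None, driver='gelsd')
print(solution.shape)  # the least-squares solution x minimizing ||a @ x - b||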
test/backward_compatibility/check_backward_compatibility.py

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@
     ("aten::irfft", datetime.date(2021, 1, 31)),
     ("aten::rfft", datetime.date(2021, 1, 31)),
     ("aten::_lstsq_helper", datetime.date(9999, 1, 1)),
+    ("aten::linalg_lstsq", datetime.date(2021, 5, 1)),
     ("aten::_svd_helper", datetime.date(2021, 1, 31)),
     ("aten::_syevd_helper", datetime.date(9999, 1, 1)),
     ("aten::_cudnn_rnn_flatten_weight", datetime.date(2020, 12, 31)),

test/test_linalg.py

Lines changed: 38 additions & 33 deletions
@@ -125,7 +125,7 @@ def test_linalg_lstsq(self, device, dtype):
         else:
             drivers = ('gels', None)

-        def check_correctness(a, b, sol):
+        def check_solution_correctness(a, b, sol):
             sol2 = a.pinverse() @ b
             self.assertEqual(sol, sol2, atol=1e-5, rtol=1e-5)

@@ -196,29 +196,22 @@ def select_if_not_empty(t, i):
             self.assertEqual(res.singular_values.shape, (0, ))

         def check_correctness_scipy(a, b, res, driver, cond):
-            if TEST_SCIPY and driver not in (None, 'gels'):
+            # SciPy provides 3 driver options: gelsd, gelss, gelsy
+            if TEST_SCIPY and driver in ('gelsd', 'gelss', 'gelsy'):
                 import scipy.linalg

                 def scipy_ref(a, b):
                     return scipy.linalg.lstsq(a, b, lapack_driver=driver, cond=cond)
                 check_correctness_ref(a, b, res, scipy_ref, driver=driver)

-        def check_correctness_numpy(a, b, res, driver, cond):
-            if driver in ('gelsd', 'gelss'):
-                import numpy.linalg
+        def check_correctness_numpy(a, b, res, driver, rcond):
+            # NumPy uses only gelsd routine
+            if driver == 'gelsd':

                 def numpy_ref(a, b):
-                    return numpy.linalg.lstsq(a, b, rcond=-1 if cond is None else cond)
+                    return np.linalg.lstsq(a, b, rcond=rcond)
                 check_correctness_ref(a, b, res, numpy_ref)

-        def check_ranks(a, ranks, cond=1e-7):
-            ranks2 = torch.matrix_rank(a, tol=cond)
-            self.assertEqual(ranks, ranks2)
-
-        def check_singular_values(a, sv):
-            sv2 = a.svd()[1]
-            self.assertEqual(sv, sv2)
-
         ms = [2 ** i for i in range(5)]
         m_ge_n_sizes = [(m, m // 2) for m in ms] + [(m, m) for m in ms]
         # cases m < n are only supported on CPU
@@ -229,32 +222,44 @@ def check_singular_values(a, sv):
         # that is why we use `cond=1.0`, the mean to cut roughly half of all
         # the singular values and compare whether torch.linalg.lstsq agrees with
         # SciPy and NumPy.
-        cond = (None, 1.0)
+        # if rcond is True then set value for it based on the used algorithm
+        # rcond == -1 or any other negative value forces LAPACK to use machine precision tolerance
+        rconds = (None, True, -1)
+
+        for batch, matrix_size, driver, rcond in itertools.product(batches, matrix_sizes, drivers, rconds):
+            # keep the rcond value if it is None or -1, set the driver specific value if it is True
+            if rcond and rcond != -1:
+                if driver in ('gelss', 'gelsd'):
+                    # SVD based algorithm; set to zero roughly half of all the singular values
+                    rcond = 1.0
+                else:
+                    # driver == 'gelsy'
+                    # QR based algorithm; setting the value too high might lead to non-unique solutions and flaky tests
+                    rcond = 1e-4
+
+            # specifying rcond value has no effect for gels driver so no need to run the tests again
+            if driver == 'gels' and rcond is not None:
+                continue

-        for batch, matrix_size, driver, cond in itertools.product(batches, matrix_sizes, drivers, cond):
             shape = batch + matrix_size
             a = random_well_conditioned_matrix(*shape, dtype=dtype, device=device)
             b = torch.rand(*shape, dtype=dtype, device=device)

-            cond = 1e-7
             m = a.size(-2)
             n = a.size(-1)
-            res = torch.linalg.lstsq(a, b, cond=cond, driver=driver)
-            sol = res.solution.narrow(-2, 0, n)
-
-            check_correctness_scipy(a, b, res, driver, cond)
-            check_correctness_numpy(a, b, res, driver, cond)
-
-            check_correctness(a, b, sol)
-            if self.device_type == 'cpu' and driver != 'gels':
-                # rank-revealing drivers are only available for the CPU.
-                # `gels` is not rank-revealing and is only for full
-                # rank inputs.
-                check_ranks(a, res.rank, cond)
-            if self.device_type == 'cpu' and driver in ('gelsd', 'gelss'):
-                # SVD-based drivers are only available for the CPU.
-                # These are only `gelsd` and `gelss`.
-                check_singular_values(a, res.singular_values)
+            res = torch.linalg.lstsq(a, b, rcond=rcond, driver=driver)
+            sol = res.solution
+
+            # Only checks gelsd, gelss, gelsy drivers
+            check_correctness_scipy(a, b, res, driver, rcond)
+
+            # Only checks gelsd driver
+            check_correctness_numpy(a, b, res, driver, rcond)
+
+            # gels driver is not checked by comparing to NumPy or SciPy implementation
+            # because NumPy and SciPy do not implement this driver
+            if driver == 'gels' and rcond is None:
+                check_solution_correctness(a, b, sol)

     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack

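A standalone sketch of the NumPy comparison the updated test performs for the gelsd driver (assumes NumPy is installed; the tolerance is illustrative, not taken from the test):

import numpy as np
import torch

a = torch.randn(6, 4, dtype=torch.float64)
b = torch.randn(6, 2, dtype=torch.float64)

# with a matching rcond both libraries cut the same singular values,
# so the minimum-norm gelsd solutions should agree closely
res = torch.linalg.lstsq(a, b, rcond=1.0, driver='gelsd')
x_np, _, rank_np, _ = np.linalg.lstsq(a.numpy(), b.numpy(), rcond=1.0)

np.testing.assert_allclose(res.solution.numpy(), x_np, atol=1e-8)
print(int(res.rank), rank_np)  # the effective ranks should match as well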
tools/autograd/derivatives.yaml

Lines changed: 1 addition & 1 deletion
@@ -704,7 +704,7 @@
   self: not_implemented("lstsq")
   A: not_implemented("lstsq")

-- name: linalg_lstsq(Tensor self, Tensor b, float? cond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values)
+- name: linalg_lstsq(Tensor self, Tensor b, float? rcond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values)
   self: not_implemented("linalg_lstsq")
   b: not_implemented("linalg_lstsq")
   output_differentiability: [True, True]

torch/linalg/__init__.py

Lines changed: 5 additions & 5 deletions
@@ -648,7 +648,7 @@
 """)

 lstsq = _add_docstr(_linalg.linalg_lstsq, r"""
-torch.linalg.lstsq(A, B, cond=None, *, driver=None) -> (Tensor, Tensor, Tensor, Tensor)
+torch.linalg.lstsq(A, B, rcond=None, *, driver=None) -> (Tensor, Tensor, Tensor, Tensor)

 Computes a solution to the least squares problem of a system of linear equations.

@@ -714,16 +714,16 @@
     computations separately.

 .. warning::
-    The default value of :attr:`cond` may change in the future.
+    The default value of :attr:`rcond` may change in a future PyTorch release.
     It is therefore recommended to use a fixed value to avoid potential
     breaking changes.

 Args:
     A (Tensor): lhs tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions.
     B (Tensor): rhs tensor of shape `(*, m, k)` where `*` is zero or more batch dimensions.
-    cond (float, optional): used to determine the effective rank of :attr:`A`.
-        If :attr:`cond`\ `= None`, :attr:`cond` is set to the machine
-        precision of the dtype of :attr:`A`. Default: `None`.
+    rcond (float, optional): used to determine the effective rank of :attr:`A`.
+        If :attr:`rcond`\ `= None`, :attr:`rcond` is set to the machine
+        precision of the dtype of :attr:`A` times `max(m, n)`. Default: `None`.

 Keyword args:
     driver (str, optional): name of the LAPACK/MAGMA method to be used.

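A small usage sketch (not from the docstring) of how rcond controls the effective rank, assuming a CPU build with LAPACK so the rank-revealing gelsd driver is available:

import torch

A = torch.tensor([[1., 1.],
                  [1., 1.],
                  [1., 1. + 1e-7]], dtype=torch.float64)  # nearly rank-1
B = torch.ones(3, 1, dtype=torch.float64)

coarse = torch.linalg.lstsq(A, B, rcond=1e-3, driver='gelsd')
default = torch.linalg.lstsq(A, B, driver='gelsd')

print(coarse.rank)   # tensor(1): the tiny second singular value is cut off
print(default.rank)  # tensor(2): the default cutoff eps * max(m, n) keeps it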