8000 ENH: support for empty matrices in linalg.lstsq · numpy/numpy@cb0fc23 · GitHub
[go: up one dir, main page]

Skip to content

Commit cb0fc23

Browse files
committed
ENH: support for empty matrices in linalg.lstsq
1 parent 977431a commit cb0fc23

File tree

3 files changed

+41
-5
lines changed

3 files changed

+41
-5
lines changed

doc/release/1.16.0-notes.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ New Features
4040
Improvements
4141
============
4242

43+
``linalg.lstsq`` now works with empty matrices
44+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
45+
Previously, a ``LinAlgError`` would be raised when an empty matrix/empty
46+
matrices (with zero rows and/or columns) is passed in. Now outputs of
47+
appropriate shapes are returned.
48+
4349
``randint`` and ``choice`` now work on empty distributions
4450
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4551
Even when no elements needed to be drawn, ``np.random.randint`` and

numpy/linalg/linalg.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2110,7 +2110,6 @@ def lstsq(a, b, rcond="warn"):
21102110
if is_1d:
21112111
b = b[:, newaxis]
21122112
_assertRank2(a, b)
2113-
_assertNoEmpty2d(a, b) # TODO: relax this constraint
21142113
m, n = a.shape[-2:]
21152114
m2, n_rhs = b.shape[-2:]
21162115
if m != m2:
@@ -2141,7 +2140,16 @@ def lstsq(a, b, rcond="warn"):
21412140

21422141
signature = 'DDd->Ddid' if isComplexType(t) else 'ddd->ddid'
21432142
extobj = get_linalg_error_extobj(_raise_linalgerror_lstsq)
2143+
if n_rhs == 0:
2144+
# lapack can't handle n_rhs = 0 - so allocate the array one larger in that axis
2145+
b = zeros(b.shape[:-2] + (m, n_rhs + 1), dtype=b.dtype)
21442146
x, resids, rank, s = gufunc(a, b, rcond, signature=signature, extobj=extobj)
2147+
if m == 0:
2148+
x[...] = 0
2149+
if n_rhs == 0:
2150+
# remove the item we added
2151+
x = x[..., :n_rhs]
2152+
resids = resids[..., :n_rhs]
21452153

21462154
# remove the axis we added
21472155
if is_1d:

numpy/linalg/tests/test_linalg.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -875,14 +875,12 @@ def test_0_size(self):
875875
class LstsqCases(LinalgSquareTestCase, LinalgNonsquareTestCase):
876876

877877
def do(self, a, b, tags):
878-
if 'size-0' in tags:
879-
assert_raises(LinAlgError, linalg.lstsq, a, b)
880-
return
881-
882878
arr = np.asarray(a)
883879
m, n = arr.shape
884880
u, s, vt = linalg.svd(a, 0)
885881
x, residuals, rank, sv = linalg.lstsq(a, b, rcond=-1)
882+
if m == 0:
883+
assert_((x == 0).all())
886884
if m <= n:
887885
assert_almost_equal(b, dot(a, x))
888886
assert_equal(rank, m)
@@ -923,6 +921,30 @@ def test_future_rcond(self):
923921
# Warning should be raised exactly once (first command)
924922
assert_(len(w) == 1)
925923

924+
@pytest.mark.parametrize(["m", "n", "n_rhs"], [
925+
(4, 2, 2),
926+
(0, 4, 1),
927+
(0, 4, 2),
928+
(4, 0, 1),
929+
(4, 0, 2),
930+
(4, 2, 0),
931+
(0, 0, 0)
932+
])
933+
def test_empty_a_b(self, m, n, n_rhs):
934+
a = np.arange(m * n).reshape(m, n)
935+
b = np.ones((m, n_rhs))
936+
x, residuals, rank, s = linalg.lstsq(a, b, rcond=None)
937+
if m == 0:
938+
assert_( 8057 (x == 0).all())
939+
assert_equal(x.shape, (n, n_rhs))
940+
assert_equal(residuals.shape, ((n_rhs,) if m > n else (0,)))
941+
if m > n and n_rhs > 0:
942+
# residuals are exactly the squared norms of b's columns
943+
r = b - np.dot(a, x)
944+
assert_almost_equal(residuals, (r * r).sum(axis=-2))
945+
assert_equal(rank, min(m, n))
946+
assert_equal(s.shape, (min(m, n),))
947+
926948

927949
class TestMatrixPower(object):
928950
R90 = array([[0, 1], [-1, 0]])

0 commit comments

Comments
 (0)
0