8000 ENH: handle empty matrices in qr decomposition (#11593) · numpy/numpy@8fdc446 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8fdc446

Browse files
convexseteric-wieser
authored andcommitted
ENH: handle empty matrices in qr decomposition (#11593)
Ensure LWORK and LDA respect the requirements of the lapack methods (zgeqrf, dgeqrf, zungqr, dorgqr)
1 parent 9bb569c commit 8fdc446

File tree

3 files changed

+32
-19
lines changed

3 files changed

+32
-19
lines changed

doc/release/1.16.0-notes.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@ Even when no elements needed to be drawn, ``np.random.randint`` and
4747
distribution. This has been fixed so that e.g.
4848
``np.random.choice([], 0) == np.array([], dtype=float64)``.
4949

50+
``linalg.qr`` now works with empty matrices
51+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
52+
Previously, a ``LinAlgError`` would be raised when empty matrix
53+
(with zero rows and/or columns) is passed in. This has been fixed
54+
so that outputs of appropriate shapes are returned for the various modes.
55+
5056
ARM support updated
5157
-------------------
5258
Support for ARM CPUs has been updated to accommodate 32 and 64 bit targets,

numpy/linalg/linalg.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -858,13 +858,13 @@ def qr(a, mode='reduced'):
858858

859859
a, wrap = _makearray(a)
860860
_assertRank2(a)
861-
_assertNoEmpty2d(a)
862861
m, n = a.shape
863862
t, result_t = _commonType(a)
864863
a = _fastCopyAndTranspose(t, a)
865864
a = _to_native_byte_order(a)
866865
mn = min(m, n)
867866
tau = zeros((mn,), t)
867+
868868
if isComplexType(t):
869869
lapack_routine = lapack_lite.zgeqrf
870870
routine_name = 'zgeqrf'
@@ -875,14 +875,14 @@ def qr(a, mode='reduced'):
875875
# calculate optimal size of work data 'work'
876876
lwork = 1
877877
work = zeros((lwork,), t)
878-
results = lapack_routine(m, n, a, m, tau, work, -1, 0)
878+
results = lapack_routine(m, n, a, max(1, m), tau, work, -1, 0)
879879
if results['info'] != 0:
880880
raise LinAlgError('%s returns %d' % (routine_name, results['info']))
881881

882882
# do qr decomposition
883-
lwork = int(abs(work[0]))
883+
lwork = max(1, n, int(abs(work[0])))
884884
work = zeros((lwork,), t)
885-
results = lapack_routine(m, n, a, m, tau, work, lwork, 0)
885+
results = lapack_routine(m, n, a, max(1, m), tau, work, lwork, 0)
886886
if results['info'] != 0:
887887
raise LinAlgError('%s returns %d' % (routine_name, results['info']))
888888

@@ -918,14 +918,14 @@ def qr(a, mode='reduced'):
918918
# determine optimal lwork
919919
lwork = 1
920920
work = zeros((lwork,), t)
921-
results = lapack_routine(m, mc, mn, q, m, tau, work, -1, 0)
921+
results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, -1, 0)
922922
if results['info'] != 0:
923923
raise LinAlgError('%s returns %d' % (routine_name, results['info']))
924924

925925
# compute q
926-
lwork = int(abs(work[0]))
926+
lwork = max(1, n, int(abs(work[0])))
927927
work = zeros((lwork,), t)
928-
results = lapack_routine(m, mc, mn, q, m, tau, work, lwork, 0)
928+
results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, lwork, 0)
929929
if results['info'] != 0:
930930
raise LinAlgError('%s returns %d' % (routine_name, results['info']))
931931

numpy/linalg/tests/test_linalg.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1582,9 +1582,25 @@ def check_qr(self, a):
15821582
assert_(isinstance(r2, a_type))
15831583
assert_almost_equal(r2, r1)
15841584

1585-
def test_qr_empty(self):
1586-
a = np.zeros((0, 2))
1587-
assert_raises(linalg.LinAlgError, linalg.qr, a)
1585+
1586+
@pytest.mark.parametrize(["m", "n"], [
1587+
(3, 0),
1588+
(0, 3),
1589+
(0, 0)
1590+
])
1591+
def test_qr_empty(self, m, n):
1592+
k = min(m, n)
1593+
a = np.empty((m, n))
1594+
a_type = type(a)
1595+
a_dtype = a.dtype
1596+
1597+
self.check_qr(a)
1598+
1599+
h, tau = np.linalg.qr(a, mode='raw')
1600+
assert_equal(h.dtype, np.double)
1601+
assert_equal(tau.dtype, np.double)
1602+
assert_equal(h.shape, (n, m))
1603+
assert_equal(tau.shape, (k,))
15881604

15891605
def test_mode_raw(self):
15901606
# The factorization is not unique and varies between libraries,
@@ -1625,15 +1641,6 @@ def test_mode_all_but_economic(self):
16251641
self.check_qr(m2)
16261642
self.check_qr(m2.T)
16271643

1628-
def test_0_size(self):
1629-
# There may be good ways to do (some of this) reasonably:
1630-
a = np.zeros((0, 0))
1631-
assert_raises(linalg.LinAlgError, linalg.qr, a)
1632-
a = np.zeros((0, 1))
1633-
assert_raises(linalg.LinAlgError, linalg.qr, a)
1634-
a = np.zeros((1, 0))
1635-
assert_raises(linalg.LinAlgError, linalg.qr, a)
1636-
16371644

16381645
class TestCholesky(object):
16391646
# TODO: are there no other tests for cholesky?

0 commit comments

Comments
 (0)
0