8000 Merge pull request #8368 from eric-wieser/0x0-linalg · numpy/numpy@1532532 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1532532

Browse files
authored
Merge pull request #8368 from eric-wieser/0x0-linalg
ENH: Implement most linalg operations for 0x0 matrices
2 parents 96c3e66 + 6627449 commit 1532532

File tree

3 files changed

+40
-29
lines changed

3 files changed

+40
-29
lines changed

doc/release/1.13.0-notes.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,12 @@ np.matrix with booleans elements can now be created using the string syntax
194194
``np.matrix`` failed whenever one attempts to use it with booleans, e.g.,
195195
``np.matrix('True')``. Now, this works as expected.
196196

197+
More ``linalg`` operations now accept empty vectors and matrices
198+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
199+
All of the following functions in ``np.linalg`` now work when given input
200+
arrays with a 0 in the last two dimensions: `det``, ``slogdet``, ``pinv``,
201+
``eigvals``, ``eigvalsh``, ``eig``, ``eigh``.
202+
197203
Changes
198204
=======
199205

numpy/linalg/linalg.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
csingle, cdouble, inexact, complexfloating, newaxis, ravel, all, Inf, dot,
2424
add, multiply, sqrt, maximum, fastCopyAndTranspose, sum, isfinite, size,
2525
finfo, errstate, geterrobj, longdouble, rollaxis, amin, amax, product, abs,
26-
broadcast, atleast_2d, intp, asanyarray, isscalar, object_
26+
broadcast, atleast_2d, intp, asanyarray, isscalar, object_, ones
2727
)
2828
from numpy.core.multiarray import normalize_axis_index
2929
from numpy.lib import triu, asfarray
@@ -217,9 +217,13 @@ def _assertFinite(*arrays):
217217
if not (isfinite(a).all()):
218218
raise LinAlgError("Array must not contain infs or NaNs")
219219

220+
def _isEmpty2d(arr):
221+
# check size first for efficiency
222+
return arr.size == 0 and product(arr.shape[-2:]) == 0
223+
220224
def _assertNoEmpty2d(*arrays):
221225
for a in arrays:
222-
if a.size == 0 and product(a.shape[-2:]) == 0:
226+
if _isEmpty2d(a):
223227
raise LinAlgError("Arrays cannot be empty")
224228

225229

@@ -898,11 +902,12 @@ def eigvals(a):
898902
899903
"""
900904
a, wrap = _makearray(a)
901-
_assertNoEmpty2d(a)
902905
_assertRankAtLeast2(a)
903906
_assertNdSquareness(a)
904907
_assertFinite(a)
905908
t, result_t = _commonType(a)
909+
if _isEmpty2d(a):
910+
return empty(a.shape[-1:], dtype=result_t)
906911

907912
extobj = get_linalg_error_extobj(
908913
_raise_linalgerror_eigenvalues_nonconvergence)
@@ -1002,10 +1007,11 @@ def eigvalsh(a, UPLO='L'):
10021007
gufunc = _umath_linalg.eigvalsh_up
10031008

10041009
a, wrap = _makearray(a)
1005-
_assertNoEmpty2d(a)
10061010
_assertRankAtLeast2(a)
10071011
_assertNdSquareness(a)
10081012
t, result_t = _commonType(a)
1013+
if _isEmpty2d(a):
1014+
return empty(a.shape[-1:], dtype=result_t)
10091015
signature = 'D->d' if isComplexType(t) else 'd->d'
10101016
w = gufunc(a, signature=signature, extobj=extobj)
10111017
return w.astype(_realType(result_t), copy=False)
@@ -1139,11 +1145,14 @@ def eig(a):
11391145
11401146
"""
11411147
a, wrap = _makearray(a)
1142-
_assertNoEmpty2d(a)
11431148
_assertRankAtLeast2(a)
11441149
_assertNdSquareness(a)
11451150
_assertFinite(a)
11461151
t, result_t = _commonType(a)
1152+
if _isEmpty2d(a):
1153+
w = empty(a.shape[-1:], dtype=result_t)
1154+
vt = empty(a.shape, dtype=result_t)
1155+
return w, wrap(vt)
11471156

11481157
extobj = get_linalg_error_extobj(
11491158
_raise_linalgerror_eigenvalues_nonconvergence)
@@ -1280,8 +1289,11 @@ def eigh(a, UPLO='L'):
12801289
a, wrap = _makearray(a)
12811290
_assertRankAtLeast2(a)
12821291
_assertNdSquareness(a)
1283-
_assertNoEmpty2d(a)
12841292
t, result_t = _commonType(a)
1293+
if _isEmpty2d(a):
1294+
w = empty(a.shape[-1:], dtype=result_t)
1295+
vt = empty(a.shape, dtype=result_t)
1296+
return w, wrap(vt)
12851297

12861298
extobj = get_linalg_error_extobj(
12871299
_raise_linalgerror_eigenvalues_nonconvergence)
@@ -1660,7 +1672,9 @@ def pinv(a, rcond=1e-15 ):
16601672
16611673
"""
16621674
a, wrap = _makearray(a)
1663-
_assertNoEmpty2d(a)
1675+
if _isEmpty2d(a):
1676+
res = empty(a.shape[:-2] + (a.shape[-1], a.shape[-2]), dtype=a.dtype)
1677+
return wrap(res)
16641678
a = a.conjugate()
16651679
u, s, vt = svd(a, 0)
16661680
m = u.shape[0]
@@ -1751,11 +1765,15 @@ def slogdet(a):
17511765
17521766
"""
17531767
a = asarray(a)
1754-
_assertNoEmpty2d(a)
17551768
_assertRankAtLeast2(a)
17561769
_assertNdSquareness(a)
17571770
t, result_t = _commonType(a)
17581771
real_t = _realType(result_t)
1772+
if _isEmpty2d(a):
1773+
# determinant of empty matrix is 1
1774+
sign = ones(a.shape[:-2], dtype=result_t)
1775+
logdet = zeros(a.shape[:-2], dtype=real_t)
1776+
return sign, logdet
17591777
signature = 'D->Dd' if isComplexType(t) else 'd->dd'
17601778
sign, logdet = _umath_linalg.slogdet(a, signature=signature)
17611779
if isscalar(sign):
@@ -1816,10 +1834,12 @@ def det(a):
18161834
18171835
"""
18181836
a = asarray(a)
1819-
_assertNoEmpty2d(a)
18201837
_assertRankAtLeast2(a)
18211838
_assertNdSquareness(a)
18221839
t, result_t = _commonType(a)
1840+
# 0x0 matrices have determinant 1
1841+
if _isEmpty2d(a):
1842+
return ones(a.shape[:-2], dtype=result_t)
18231843
signature = 'D->D' if isComplexType(t) else 'd->d'
18241844
r = _umath_linalg.det(a, signature=signature)
18251845
if isscalar(r):

numpy/linalg/tests/test_linalg.py

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,11 @@ def apply_tag(tag, cases):
127127
array([[2. + 1j, 1. + 2j, 1 + 3j], [1 - 2j, 1 - 3j, 1 - 6j]], dtype=cdouble)),
128128
LinalgCase("0x0",
129129
np.empty((0, 0), dtype=double),
130-
np.empty((0, 0), dtype=double),
130+
np.empty((0,), dtype=double),
131+
tags={'size-0'}),
132+
LinalgCase("0x0_matrix",
133+
np.empty((0, 0), dtype=double).view(np.matrix),
134+
np.empty((0, 1), dtype=double).view(np.matrix),
131135
tags={'size-0'}),
132136
LinalgCase("8x8",
133137
np.random.rand(8, 8),
@@ -549,9 +553,6 @@ class ArraySubclass(np.ndarray):
549553
class TestEigvals(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
550554

551555
def do(self, a, b, tags):
552-
if 'size-0' in tags:
553-
assert_raises(LinAlgError, linalg.eigvals, a)
554-
return
555556
ev = linalg.eigvals(a)
556557
evalues, evectors = linalg.eig(a)
557558
assert_almost_equal(ev, evalues)
@@ -569,10 +570,6 @@ def check(dtype):
569570
class TestEig(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
570571

571572
def do(self, a, b, tags):
572-
if 'size-0' in tags:
573-
assert_raises(LinAlgError, linalg.eig, a)
574-
return
575-
576573
evalues, evectors = linalg.eig(a)
577574
assert_allclose(dot_generalized(a, evectors),
578575
np.asarray(evectors) * np.asarray(evalues)[..., None, :],
@@ -667,9 +664,6 @@ def test(self):
667664
class TestPinv(LinalgSquareTestCase, LinalgNonsquareTestCase):
668665

669666
def do(self, a, b, tags):
670-
if 'size-0' in tags:
671-
assert_raises(LinAlgError, linalg.pinv, a)
672-
return
673667
a_ginv = linalg.pinv(a)
674668
# `a @ a_ginv == I` does not hold if a is singular
675669
assert_almost_equal(dot(a, a_ginv).dot(a), a, single_decimal=5, double_decimal=11)
@@ -679,9 +673,6 @@ def do(self, a, b, tags):
679673
class TestDet(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
680674

681675
def do(self, a, b, tags):
682-
if 'size-0' in tags:
683-
assert_raises(LinAlgError, linalg.det, a)
684-
return
685676
d = linalg.det(a)
686677
(s, ld) = linalg.slogdet(a)
687678
if asarray(a).dtype.type in (single, double):
@@ -820,9 +811,6 @@ def test_square(self):
820811
class TestEigvalsh(HermitianTestCase, HermitianGeneralizedTestCase):
821812

822813
def do(self, a, b, tags):
823-
if 'size-0' in tags:
824-
assert_raises(LinAlgError, linalg.eigvalsh, a, 'L')
825-
return
826814
# note that eigenvalue arrays returned by eig must be sorted since
827815
# their order isn't guaranteed.
828816
ev = linalg.eigvalsh(a, 'L')
@@ -873,9 +861,6 @@ def test_UPLO(self):
873861
class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase):
874862

875863
def do(self, a, b, tags):
876-
if 'size-0' in tags:
877-
assert_raises(LinAlgError, linalg.eigh, a)
878-
return
879864
# note that eigenvalue arrays returned by eig must be sorted since
880865
# their order isn't guaranteed.
881866
ev, evc = linalg.eigh(a)

0 commit comments

Comments
 (0)
0