|
23 | 23 | csingle, cdouble, inexact, complexfloating, newaxis, ravel, all, Inf, dot,
|
24 | 24 | add, multiply, sqrt, maximum, fastCopyAndTranspose, sum, isfinite, size,
|
25 | 25 | finfo, errstate, geterrobj, longdouble, rollaxis, amin, amax, product, abs,
|
26 |
| - broadcast, atleast_2d, intp, asanyarray, isscalar, object_ |
| 26 | + broadcast, atleast_2d, intp, asanyarray, isscalar, object_, ones |
27 | 27 | )
|
28 | 28 | from numpy.core.multiarray import normalize_axis_index
|
29 | 29 | from numpy.lib import triu, asfarray
|
@@ -217,9 +217,13 @@ def _assertFinite(*arrays):
|
217 | 217 | if not (isfinite(a).all()):
|
218 | 218 | raise LinAlgError("Array must not contain infs or NaNs")
|
219 | 219 |
|
| 220 | +def _isEmpty2d(arr): |
| 221 | + # check size first for efficiency |
| 222 | + return arr.size == 0 and product(arr.shape[-2:]) == 0 |
| 223 | + |
220 | 224 | def _assertNoEmpty2d(*arrays):
|
221 | 225 | for a in arrays:
|
222 |
| - if a.size == 0 and product(a.shape[-2:]) == 0: |
| 226 | + if _isEmpty2d(a): |
223 | 227 | raise LinAlgError("Arrays cannot be empty")
|
224 | 228 |
|
225 | 229 |
|
@@ -898,11 +902,12 @@ def eigvals(a):
|
898 | 902 |
|
899 | 903 | """
|
900 | 904 | a, wrap = _makearray(a)
|
901 |
| - _assertNoEmpty2d(a) |
902 | 905 | _assertRankAtLeast2(a)
|
903 | 906 | _assertNdSquareness(a)
|
904 | 907 | _assertFinite(a)
|
905 | 908 | t, result_t = _commonType(a)
|
| 909 | + if _isEmpty2d(a): |
| 910 | + return empty(a.shape[-1:], dtype=result_t) |
906 | 911 |
|
907 | 912 | extobj = get_linalg_error_extobj(
|
908 | 913 | _raise_linalgerror_eigenvalues_nonconvergence)
|
@@ -1002,10 +1007,11 @@ def eigvalsh(a, UPLO='L'):
|
1002 | 1007 | gufunc = _umath_linalg.eigvalsh_up
|
1003 | 1008 |
|
1004 | 1009 | a, wrap = _makearray(a)
|
1005 |
| - _assertNoEmpty2d(a) |
1006 | 1010 | _assertRankAtLeast2(a)
|
1007 | 1011 | _assertNdSquareness(a)
|
1008 | 1012 | t, result_t = _commonType(a)
|
| 1013 | + if _isEmpty2d(a): |
| 1014 | + return empty(a.shape[-1:], dtype=result_t) |
1009 | 1015 | signature = 'D->d' if isComplexType(t) else 'd->d'
|
1010 | 1016 | w = gufunc(a, signature=signature, extobj=extobj)
|
1011 | 1017 | return w.astype(_realType(result_t), copy=False)
|
@@ -1139,11 +1145,14 @@ def eig(a):
|
1139 | 1145 |
|
1140 | 1146 | """
|
1141 | 1147 | a, wrap = _makearray(a)
|
1142 |
| - _assertNoEmpty2d(a) |
1143 | 1148 | _assertRankAtLeast2(a)
|
1144 | 1149 | _assertNdSquareness(a)
|
1145 | 1150 | _assertFinite(a)
|
1146 | 1151 | t, result_t = _commonType(a)
|
| 1152 | + if _isEmpty2d(a): |
| 1153 | + w = empty(a.shape[-1:], dtype=result_t) |
| 1154 | + vt = empty(a.shape, dtype=result_t) |
| 1155 | + return w, wrap(vt) |
1147 | 1156 |
|
1148 | 1157 | extobj = get_linalg_error_extobj(
|
1149 | 1158 | _raise_linalgerror_eigenvalues_nonconvergence)
|
@@ -1280,8 +1289,11 @@ def eigh(a, UPLO='L'):
|
1280 | 1289 | a, wrap = _makearray(a)
|
1281 | 1290 | _assertRankAtLeast2(a)
|
1282 | 1291 | _assertNdSquareness(a)
|
1283 |
| - _assertNoEmpty2d(a) |
1284 | 1292 | t, result_t = _commonType(a)
|
| 1293 | + if _isEmpty2d(a): |
| 1294 | + w = empty(a.shape[-1:], dtype=result_t) |
| 1295 | + vt = empty(a.shape, dtype=result_t) |
| 1296 | + return w, wrap(vt) |
1285 | 1297 |
|
1286 | 1298 | extobj = get_linalg_error_extobj(
|
1287 | 1299 | _raise_linalgerror_eigenvalues_nonconvergence)
|
@@ -1660,7 +1672,9 @@ def pinv(a, rcond=1e-15 ):
|
1660 | 1672 |
|
1661 | 1673 | """
|
1662 | 1674 | a, wrap = _makearray(a)
|
1663 |
| - _assertNoEmpty2d(a) |
| 1675 | + if _isEmpty2d(a): |
| 1676 | + res = empty(a.shape[:-2] + (a.shape[-1], a.shape[-2]), dtype=a.dtype) |
| 1677 | + return wrap(res) |
1664 | 1678 | a = a.conjugate()
|
1665 | 1679 | u, s, vt = svd(a, 0)
|
1666 | 1680 | m = u.shape[0]
|
@@ -1751,11 +1765,15 @@ def slogdet(a):
|
1751 | 1765 |
|
1752 | 1766 | """
|
1753 | 1767 | a = asarray(a)
|
1754 |
| - _assertNoEmpty2d(a) |
1755 | 1768 | _assertRankAtLeast2(a)
|
1756 | 1769 | _assertNdSquareness(a)
|
1757 | 1770 | t, result_t = _commonType(a)
|
1758 | 1771 | real_t = _realType(result_t)
|
| 1772 | + if _isEmpty2d(a): |
| 1773 | + # determinant of empty matrix is 1 |
| 1774 | + sign = ones(a.shape[:-2], dtype=result_t) |
| 1775 | + logdet = zeros(a.shape[:-2], dtype=real_t) |
| 1776 | + return sign, logdet |
1759 | 1777 | signature = 'D->Dd' if isComplexType(t) else 'd->dd'
|
1760 | 1778 | sign, logdet = _umath_linalg.slogdet(a, signature=signature)
|
1761 | 1779 | if isscalar(sign):
|
@@ -1816,10 +1834,12 @@ def det(a):
|
1816 | 1834 |
|
1817 | 1835 | """
|
1818 | 1836 | a = asarray(a)
|
1819 |
| - _assertNoEmpty2d(a) |
1820 | 1837 | _assertRankAtLeast2(a)
|
1821 | 1838 | _assertNdSquareness(a)
|
1822 | 1839 | t, result_t = _commonType(a)
|
| 1840 | + # 0x0 matrices have determinant 1 |
| 1841 | + if _isEmpty2d(a): |
| 1842 | + return ones(a.shape[:-2], dtype=result_t) |
1823 | 1843 | signature = 'D->D' if isComplexType(t) else 'd->d'
|
1824 | 1844 | r = _umath_linalg.det(a, signature=signature)
|
1825 | 1845 | if isscalar(r):
|
|
0 commit comments