8000 Merge pull request #3861 from seberg/nditer-remove-empty · numpy/numpy@8d01ee4 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8d01ee4

Browse files
authored
Merge pull request #3861 from seberg/nditer-remove-empty
ENH: Make it possible to NpyIter_RemoveAxis an empty dimension
2 parents af54fbd + 1c4e0d4 commit 8d01ee4

File tree

7 files changed

+178
-81
lines changed

7 files changed

+178
-81
lines changed

doc/release/1.13.0-notes.rst

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,24 @@ now issue a DeprecationWarning - ``.__getitem__(slice(start, end))`` should be
9999
used instead.
100100

101101

102+
C API
103+
-----
104+
105+
GUfuncs on empty arrays and NpyIter axis removal
106+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
107+
It is now allowed to remove a zero-sized axis from NpyIter. Which may mean
108+
that code removing axes from NpyIter has to add an additional check when
109+
accessing the removed dimensions later on.
110+
111+
The largest followup change is that gufuncs are now allowed to have zero-sized
112+
inner dimensions. This means that a gufunc now has to anticipate an empty inner
113+
dimension, while this was never possible and an error raised instead.
114+
115+
For most gufuncs no change should be necessary. However, it is now possible
116+
for gufuncs with a signature such as ``(..., N, M) -> (..., M)`` to return
117+
a valid result if ``N=0`` without further wrapping code.
118+
119+
102120
New Features
103121
============
104122

numpy/core/src/multiarray/nditer_api.c

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,6 @@ NpyIter_RemoveAxis(NpyIter *iter, int axis)
106106
return NPY_FAIL;
107107
}
108108

109-
if (NAD_SHAPE(axisdata_del) == 0) {
110-
PyErr_SetString(PyExc_ValueError,
111-
"cannot remove a zero-sized axis from an iterator");
112-
return NPY_FAIL;
113-
}
114-
115109
/* Adjust the permutation */
116110
for (idim = 0; idim < ndim-1; ++idim) {
117111
npy_int8 p = (idim < xdim) ? perm[idim] : perm[idim+1];

numpy/core/src/umath/ufunc_object.c

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3269,14 +3269,6 @@ PyUFunc_Accumulate(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *out,
32693269
op[0] = NpyIter_GetOperandArray(iter)[0];
32703270
op[1] = NpyIter_GetOperandArray(iter)[1];
32713271

3272-
if (PyArray_SIZE(op[0]) == 0) {
3273-
if (out == NULL) {
3274-
out = op[0];
3275-
Py_INCREF(out);
3276-
}
3277-
goto finish;
3278-
}
3279-
32803272
if (NpyIter_RemoveAxis(iter, axis) != NPY_SUCCEED) {
32813273
goto fail;
32823274
}
@@ -3623,35 +3615,6 @@ PyUFunc_Reduceat(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *ind,
36233615
need_outer_iterator = 1;
36243616
}
36253617

3626-
/* Special case when the index array's size is zero */
3627-
if (ind_size == 0) {
3628-
if (out == NULL) {
3629-
npy_intp out_shape[NPY_MAXDIMS];
3630-
memcpy(out_shape, PyArray_SHAPE(arr),
3631-
PyArray_NDIM(arr) * NPY_SIZEOF_INTP);
3632-
out_shape[axis] = 0;
3633-
Py_INCREF(op_dtypes[0]);
3634-
op[0] = out = (PyArrayObject *)PyArray_NewFromDescr(
3635-
&PyArray_Type, op_dtypes[0],
3636-
PyArray_NDIM(arr), out_shape, NULL, NULL,
3637-
0, NULL);
3638-
if (out == NULL) {
3639-
goto fail;
3640-
}
3641-
}
3642-
else {
3643-
/* Allow any zero-sized output array in this case */
3644-
if (PyArray_SIZE(out) != 0) {
3645-
PyErr_SetString(PyExc_ValueError,
3646-
"output operand shape for reduceat is "
3647-
"incompatible with index array of shape (0,)");
3648-
goto fail;
3649-
}
3650-
}
3651-
3652-
goto finish;
3653-
}
3654-
36553618
if (need_outer_iterator) {
36563619
npy_uint32 flags = NPY_ITER_ZEROSIZE_OK|
36573620
NPY_ITER_REFS_OK|

numpy/core/src/umath/umath_tests.c.src

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,20 @@ static void
155155
npy_intp ib2_n = is2_n*dn;
156156
npy_intp ib2_p = is2_p*dp;
157157
npy_intp ob_p = os_p *dp;
158+
if (dn == 0) {
159+
/* No operand, need to zero the output */
160+
BEGIN_OUTER_LOOP_3
161+
char *op=args[2];
162+
for (m = 0; m < dm; m++) {
163+
for (p = 0; p < dp; p++) {
164+
*(@typ@ *)op = 0;
165+
op += os_p;
166+
}
167+
op += os_m - ob_p;
168+
}
169+
END_OUTER_LOOP
170+
return;
171+
}
158172
BEGIN_OUTER_LOOP_3
159173
char *ip1=args[0], *ip2=args[1], *op=args[2];
160174
for (m = 0; m < dm; m++) {

numpy/core/tests/test_ufunc.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,12 @@ def test_matrix_multiply(self):
542542
self.compare_matrix_multiply_results(np.long)
543543
self.compare_matrix_multiply_results(np.double)
544544

545+
def test_matrix_multiply_umath_empty(self):
546+
res = umt.matrix_multiply(np.ones((0, 10)), np.ones((10, 0)))
547+
assert_array_equal(res, np.zeros((0, 0)))
548+
res = umt.matrix_multiply(np.ones((10, 0)), np.ones((0, 10)))
549+
assert_array_equal(res, np.zeros((10, 10)))
550+
545551
def compare_matrix_multiply_results(self, tp):
546552
d1 = np.array(np.random.rand(2, 3, 4), dtype=tp)
547553
d2 = np.array(np.random.rand(2, 3, 4), dtype=tp)

numpy/linalg/linalg.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -366,21 +366,8 @@ def solve(a, b):
366366
# We use the b = (..., M,) logic, only if the number of extra dimensions
367367
# match exactly
368368
if b.ndim == a.ndim - 1:
369-
if a.shape[-1] == 0 and b.shape[-1] == 0:
370-
# Legal, but the ufunc cannot handle the 0-sized inner dims
371-
# let the ufunc handle all wrong cases.
372-
a = a.reshape(a.shape[:-1])
373-
bc = broadcast(a, b)
374-
return wrap(empty(bc.shape, dtype=result_t))
375-
376369
gufunc = _umath_linalg.solve1
377370
else:
378-
if b.size == 0:
379-
if (a.shape[-1] == 0 and b.shape[-2] == 0) or b.shape[-1] == 0:
380-
a = a[:,:1].reshape(a.shape[:-1] + (1,))
381-
bc = broadcast(a, b)
382-
return wrap(empty(bc.shape, dtype=result_t))
383-
384371
gufunc = _umath_linalg.solve
385372

386373
signature = 'DD->D' if isComplexType(t) else 'dd->d'
@@ -521,10 +508,6 @@ def inv(a):
521508
_assertNdSquareness(a)
522509
t, result_t = _commonType(a)
523510

524-
if a.shape[-1] == 0:
525-
# The inner array is 0x0, the ufunc cannot handle this case
526-
return wrap(empty_like(a, dtype=result_t))
527-
528511
signature = 'D->D' if isComplexType(t) else 'd->d'
529512
extobj = get_linalg_error_extobj(_raise_linalgerror_singular)
530513
ainv = _umath_linalg.inv(a, signature=signature, extobj=extobj)
@@ -905,8 +888,6 @@ def eigvals(a):
905888
_assertNdSquareness(a)
906889
_assertFinite(a)
907890
t, result_t = _commonType(a)
908-
if _isEmpty2d(a):
909-
return empty(a.shape[-1:], dtype=result_t)
910891

911892
extobj = get_linalg_error_extobj(
912893
_raise_linalgerror_eigenvalues_nonconvergence)
@@ -1009,8 +990,6 @@ def eigvalsh(a, UPLO='L'):
1009990
_assertRankAtLeast2(a)
1010991
_assertNdSquareness(a)
1011992
t, result_t = _commonType(a)
1012-
if _isEmpty2d(a):
1013-
return empty(a.shape[-1:], dtype=result_t)
1014993
signature = 'D->d' if isComplexType(t) else 'd->d'
1015994
w = gufunc(a, signature=signature, extobj=extobj)
1016995
return w.astype(_realType(result_t), copy=False)
@@ -1148,10 +1127,6 @@ def eig(a):
11481127
_assertNdSquareness(a)
11491128
_assertFinite(a)
11501129
t, result_t = _commonType(a)
1151-
if _isEmpty2d(a):
1152-
w = empty(a.shape[-1:], dtype=result_t)
1153-
vt = empty(a.shape, dtype=result_t)
1154-
return w, wrap(vt)
11551130

11561131
extobj = get_linalg_error_extobj(
11571132
_raise_linalgerror_eigenvalues_nonconvergence)
@@ -1289,10 +1264,6 @@ def eigh(a, UPLO='L'):
12891264
_assertRankAtLeast2(a)
12901265
_assertNdSquareness(a)
12911266
t, result_t = _commonType(a)
1292-
if _isEmpty2d(a):
1293-
w = empty(a.shape[-1:], dtype=result_t)
1294-
vt = empty(a.shape, dtype=result_t)
1295-
return w, wrap(vt)
12961267

12971268
extobj = get_linalg_error_extobj(
12981269
_raise_linalgerror_eigenvalues_nonconvergence)
@@ -1766,11 +1737,6 @@ def slogdet(a):
17661737
_assertNdSquareness(a)
17671738
t, result_t = _commonType(a)
17681739
real_t = _realType(result_t)
1769-
if _isEmpty2d(a):
1770-
# determinant of empty matrix is 1
1771-
sign = ones(a.shape[:-2], dtype=result_t)
1772-
logdet = zeros(a.shape[:-2], dtype=real_t)
1773-
return sign, logdet
17741740
signature = 'D->Dd' if isComplexType(t) else 'd->dd'
17751741
sign, logdet = _umath_linalg.slogdet(a, signature=signature)
17761742
if isscalar(sign):
@@ -1834,9 +1800,6 @@ def det(a):
18341800
_assertRankAtLeast2(a)
18351801
_assertNdSquareness(a)
18361802
t, result_t = _commonType(a)
1837-
# 0x0 matrices have determinant 1
1838-
if _isEmpty2d(a):
1839-
return ones(a.shape[:-2], dtype=result_t)
18401803
signature = 'D->D' if isComplexType(t) else 'd->d'
18411804
r = _umath_linalg.det(a, signature=signature)
18421805
if isscalar(r):

0 commit comments

Comments
 (0)
0