10000 Merge pull request #7177 from gfyoung/count_nonzero_axis · numpy/numpy@31a95d9 · GitHub
[go: up one dir, main page]

Skip to content

Commit 31a95d9

Browse files
authored
Merge pull request #7177 from gfyoung/count_nonzero_axis
ENH: added axis param for np.count_nonzero
2 parents bfd91d9 + 0fc9e45 commit 31a95d9

File tree

7 files changed

+225
-42
lines changed

7 files changed

+225
-42
lines changed

benchmarks/benchmarks/bench_core.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,27 @@ def time_correlate(self, size1, size2, mode):
106106

107107
def time_convolve(self, size1, size2, mode):
108108
np.convolve(self.x1, self.x2, mode=mode)
109+
110+
111+
class CountNonzero(Benchmark):
112+
param_names = ['numaxes', 'size', 'dtype']
113+
params = [
114+
[1, 2, 3],
115+
[100, 10000, 1000000],
116+
[bool, int, str, object]
117+
]
118+
119+
def setup(self, numaxes, size, dtype):
120+
self.x = np.empty(shape=(
121+
numaxes, size), dtype=dtype)
122+
123+
def time_count_nonzero(self, numaxes, size, dtype):
124+
np.count_nonzero(self.x)
125+
126+
def time_count_nonzero_axis(self, numaxes, size, dtype):
127+
np.count_nonzero(self.x, axis=self.x.ndim - 1)
128+
129+
def time_count_nonzero_multi_axis(self, numaxes, size, dtype):
130+
if self.x.ndim >= 2:
131+
np.count_nonzero(self.x, axis=(
132+
self.x.ndim - 1, self.x.ndim - 2))

benchmarks/benchmarks/bench_ufunc.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,6 @@ def setup(self):
6767
def time_nonzero(self):
6868
np.nonzero(self.b)
6969

70-
def time_count_nonzero(self):
71-
np.count_nonzero(self.b)
72-
7370
def time_not_bool(self):
7471
(~self.b)
7572

doc/release/1.12.0-notes.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ Generalized ``flip``
138138
axis=1 respectively. The newly added ``flip`` function reverses the elements of
139139
an array along any given axis.
140140

141+
* ``np.count_nonzero`` now has an ``axis`` parameter, allowing
142+
non-zero counts to be generated on more than just a flattened
143+
array object.
141144

142145
BLIS support in ``numpy.distutils``
143146
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

numpy/add_newdocs.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -942,34 +942,6 @@ def luf(lamdaexpr, *args, **kwargs):
942942
943943
""")
944944

945-
add_newdoc('numpy.core.multiarray', 'count_nonzero',
946-
"""
947-
count_nonzero(a)
948-
949-
Counts the number of non-zero values in the array ``a``.
950-
951-
Parameters
952-
----------
953-
a : array_like
954-
The array for which to count non-zeros.
955-
956-
Returns
957-
-------
958-
count : int or array of int
959-
Number of non-zero values in the array.
960-
961-
See Also
962-
--------
963-
nonzero : Return the coordinates of all the non-zero values.
964-
965-
Examples
966-
--------
967-
>>> np.count_nonzero(np.eye(4))
968-
4
969-
>>> np.count_nonzero([[0,1,7,0,0],[3,0,0,2,19]])
970-
5
971-
""")
972-
973945
add_newdoc('numpy.core.multiarray', 'set_typeDict',
974946
"""set_typeDict(dict)
975947

numpy/core/numeric.py

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import sys
77
import warnings
88

9+
import numpy as np
910
from . import multiarray
1011
from .multiarray import (
1112
_fastCopyAndTranspose as fastCopyAndTranspose, ALLOW_THREADS,
@@ -376,6 +377,89 @@ def extend_all(module):
376377
__all__.append(a)
377378

378379

380+
def count_nonzero(a, axis=None):
381+
"""
382+
Counts the number of non-zero values in the array ``a``.
383+
384+
The word "non-zero" is in reference to the Python 2.x
385+
built-in method ``__nonzero__()`` (renamed ``__bool__()``
386+
in Python 3.x) of Python objects that tests an object's
387+
"truthfulness". For example, any number is considered
388+
truthful if it is nonzero, whereas any string is considered
389+
truthful if it is not the empty string. Thus, this function
390+
(recursively) counts how many elements in ``a`` (and in
391+
sub-arrays thereof) have their ``__nonzero__()`` or ``__bool__()``
392+
method evaluated to ``True``.
393+
394+
Parameters
395+
----------
396+
a : array_like
397+
The array for which to count non-zeros.
398+
axis : int or tuple, optional
399+
Axis or tuple of axes along which to count non-zeros.
400+
Default is None, meaning that non-zeros will be counted
401+
along a flattened version of ``a``.
402+
403+
.. versionadded:: 1.12.0
404+
405+
Returns
406+
-------
407+
count : int or array of int
408+
Number of non-zero values in the array along a given axis.
409+
Otherwise, the total number of non-zero values in the array
410+
is returned.
411+
412+
See Also
413+
--------
414+
nonzero : Return the coordinates of all the non-zero values.
415+
416+
Examples
417+
--------
418+
>>> np.count_nonzero(np.eye(4))
419+
4
420+
>>> np.count_nonzero([[0,1,7,0,0],[3,0,0,2,19]])
421+
5
422+
>>> np.count_nonzero([[0,1,7,0,0],[3,0,0,2,19]], axis=0)
423+
array([1, 1, 1, 1, 1])
424+
>>> np.count_nonzero([[0,1,7,0,0],[3,0,0,2,19]], axis=1)
425+
array([2, 3])
426+
427+
"""
428+
if axis is None or axis == ():
429+
return multiarray.count_nonzero(a)
430+
431+
a = asanyarray(a)
432+
433+
if a.dtype == bool:
434+
return a.sum(axis=axis, dtype=np.intp)
435+
436+
if issubdtype(a.dtype, np.number):
437+
return (a != 0).sum(axis=axis, dtype=np.intp)
438+
439+
if (issubdtype(a.dtype, np.string_) or
440+
issubdtype(a.dtype, np.unicode_)):
441+
nullstr = a.dtype.type('')
442+
return (a != nullstr).sum(axis=axis, dtype=np.intp)
443+
444+
axis = asarray(_validate_axis(axis, a.ndim, 'axis'))
445+
counts = np.apply_along_axis(multiarray.count_nonzero, axis[0], a)
446+
447+
if axis.size == 1:
448+
return counts
449+
else:
450+
# for subsequent axis numbers, that number decreases
451+
# by one in this new 'counts' array if it was larger
452+
# than the first axis upon which 'count_nonzero' was
453+
# applied but remains unchanged if that number was
454+
# smaller than that first axis
455+
#
456+
# this trick enables us to perform counts on object-like
457+
# elements across multiple axes very quickly because integer
458+
# addition is very well optimized
459+
return counts.sum(axis=tuple(axis[1:] - (
460+
axis[1:] > axis[0])), dtype=np.intp)
461+
462+
379463
def asarray(a, dtype=None, order=None):
380464
"""Convert the input to an array.
381465
@@ -891,7 +975,7 @@ def correlate(a, v, mode='valid'):
891975
return multiarray.correlate2(a, v, mode)
892976

893977

894-
def convolve(a,v,mode='full'):
978+
def convolve(a, v, mode='full'):
895979
"""
896980
Returns the discrete, linear convolution of two one-dimensional sequences.
897981
@@ -1752,7 +1836,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
17521836
return rollaxis(cp, -1, axisc)
17531837

17541838

1755-
#Use numarray's printing function
1839+
# Use numarray's printing function
17561840
from .arrayprint import array2string, get_printoptions, set_printoptions
17571841

17581842

@@ -2283,6 +2367,7 @@ def load(file):
22832367
# These are all essentially abbreviations
22842368
# These might wind up in a special abbreviations module
22852369

2370+
22862371
def _maketup(descr, val):
22872372
dt = dtype(descr)
22882373
# Place val in all scalar tuples:

numpy/core/src/multiarray/multiarraymodule.c

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1980,16 +1980,10 @@ array_zeros(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds)
19801980
static PyObject *
19811981
array_count_nonzero(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwds)
19821982
{
1983-
PyObject *array_in;
19841983
PyArrayObject *array;
19851984
npy_intp count;
19861985

1987-
if (!PyArg_ParseTuple(args, "O", &array_in)) {
1988-
return NULL;
1989-
}
1990-
1991-
array = (PyArrayObject *)PyArray_FromAny(array_in, NULL, 0, 0, 0, NULL);
1992-
if (array == NULL) {
1986+
if (!PyArg_ParseTuple(args, "O&", PyArray_Converter, &array)) {
19931987
return NULL;
19941988
}
19951989

numpy/core/tests/test_numeric.py

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,13 @@ def test_compress(self):
6767
out = np.compress([0, 1], arr, axis=0)
6868
assert_equal(out, tgt)
6969

70+
def test_count_nonzero(self):
71+
arr = [[0, 1, 7, 0, 0],
72+
[3, 0, 0, 2, 19]]
73+
tgt = np.array([2, 3])
74+
out = np.count_nonzero(arr, axis=1)
75+
assert_equal(out, tgt)
76+
7077
def test_cumproduct(self):
7178
A = [[1, 2, 3], [4, 5, 6]]
7279
assert_(np.all(np.cumproduct(A) == np.array([1, 2, 6, 24, 120, 720])))
@@ -991,9 +998,110 @@ class C(np.ndarray):
991998
assert_(type(nzx_i) is np.ndarray)
992999
assert_(nzx_i.flags.writeable)
9931000

994-
# Tests that the array method
995-
# call works
1001+
def test_count_nonzero_axis(self):
1002+
# Basic check of functionality
1003+
m = np.array([[0, 1, 7, 0, 0], [3, 0, 0, 2, 19]])
1004+
1005+
expected = np.array([1, 1, 1, 1, 1])
1006+
assert_equal(np.count_nonzero(m, axis=0), expected)
1007+
1008+
expected = np.array([2, 3])
1009+
assert_equal(np.count_nonzero(m, axis=1), expected)
1010+
1011+
assert_raises(ValueError, np.count_nonzero, m, axis=(1, 1))
1012+
assert_raises(TypeError, np.count_nonzero, m, axis='foo')
1013+
assert_raises(ValueError, np.count_nonzero, m, axis=3)
1014+
assert_raises(TypeError, np.count_nonzero,
1015+
m, axis=np.array([[1], [2]]))
1016+
1017+
def test_count_nonzero_axis_all_dtypes(self):
1018+
# More thorough test that the axis argument is respected
1019+
# for all dtypes and responds correctly when presented with
1020+
# either integer or tuple arguments for axis
1021+
msg = "Mismatch for dtype: %s"
1022+
1023+
for dt in np.typecodes['All']:
1024+
err_msg = msg % (np.dtype(dt).name,)
1025+
1026+
if dt != 'V':
1027+
if dt != 'M':
1028+
m = np.zeros((3, 3), dtype=dt)
1029+
n = np.ones(1, dtype=dt)
1030+
1031+
m[0, 0] = n[0]
1032+
m[1, 0] = n[0]
1033+
1034+
else: # np.zeros doesn't work for np.datetime64
1035+
m = np.array(['1970-01-01'] * 9)
1036+
m = m.reshape((3, 3))
1037+
1038+
m[0, 0] = '1970-01-12'
1039+
m[1, 0] = '1970-01-12'
1040+
m = m.astype(dt)
1041+
1042+
expected = np.array([2, 0, 0])
1043+
assert_equal(np.count_nonzero(m, axis=0),
1044+
expected, err_msg=err_msg)
1045+
1046+
expected = np.array([1, 1, 0])
1047+
assert_equal(np.count_nonzero(m, axis=1),
1048+
expected, err_msg=err_msg)
1049+
1050+
expected = np.array(2)
1051+
assert_equal(np.count_nonzero(m, axis=(0, 1)),
1052+
expected, err_msg=err_msg)
1053+
assert_equal(np.count_nonzero(m, axis=None),
1054+
expected, err_msg=err_msg)
1055+
assert_equal(np.count_nonzero(m),
1056+
expected, err_msg=err_msg)
1057+
1058+
if dt == 'V':
1059+
# There are no 'nonzero' objects for np.void, so the testing
1060+
# setup is slightly different for this dtype
1061+
m = np.array([np.void(1)] * 6).reshape((2, 3))
1062+
1063+
expected = np.array([0, 0, 0])
1064+
assert_equal(np.count_nonzero(m, axis=0),
1065+
expected, err_msg=err_msg)
1066+
1067+
expected = np.array([0, 0])
1068+
assert_equal(np.count_nonzero(m, axis=1),
1069+
expected, err_msg=err_msg)
1070+
1071+
expected = np.array(0)
1072+
assert_equal(np.count_nonzero(m, axis=(0, 1)),
1073+
expected, err_msg=err_msg)
1074+
assert_equal(np.count_nonzero(m, axis=None),
1075+
expected, err_msg=err_msg)
1076+
assert_equal(np.count_nonzero(m),
1077+
expected, err_msg=err_msg)
1078+
1079+
def test_count_nonzero_axis_consistent(self):
1080+
# Check that the axis behaviour for valid axes in
1081+
# non-special cases is consistent (and therefore
1082+
# correct) by checking it against an integer array
1083+
# that is then casted to the generic object dtype
1084+
from itertools import combinations, permutations
1085+
1086+
axis = (0, 1, 2, 3)
1087+
size = (5, 5, 5, 5)
1088+
msg = "Mismatch for axis: %s"
1089+
1090+
rng = np.random.RandomState(1234)
1091+
m = rng.randint(-100, 100, size=size)
1092+
n = m.astype(np.object)
1093+
1094+
for length in range(len(axis)):
1095+
for combo in combinations(axis, length):
1096+
for perm in permutations(combo):
1097+
assert_equal(
1098+
np.count_nonzero(m, axis=perm),
1099+
np.count_nonzero(n, axis=perm),
1100+
err_msg=msg % (perm,))
1101+
9961102
def test_array_method(self):
1103+
# Tests that the array method
1104+
# call to nonzero works
9971105
m = np.array([[1, 0, 0], [4, 0, 6]])
9981106
tgt = [[0, 1, 1], [0, 0, 2]]
9991107

0 commit comments

Comments
 (0)
0