8000 ENH: Always produce a consistent shape in the result of `argwhere` by eric-wieser · Pull Request #13610 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

ENH: Always produce a consistent shape in the result of argwhere #13610

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
ENH: Always produce a consistent shape in the result of argwhere
Previously this would return 1d indices even though the array is zero-d.

Note that using atleast1d inside numeric required an import change to avoid a circular import.
  • Loading branch information
eric-wieser committed Sep 6, 2019
commit b6a3ee3b7a961cfc7bcf8740c2bc89153c07f6b2
5 changes: 5 additions & 0 deletions doc/release/upcoming_changes/13610.improvement.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
``argwhere`` now produces a consistent result on 0d arrays
----------------------------------------------------------
On N-d arrays, `numpy.argwhere` now always produces an array of shape
``(n_non_zero, arr.ndim)``, even when ``arr.ndim == 0``. Previously, the
last axis would have a dimension of 1 in this case.
13 changes: 11 additions &am 10000 p; 2 deletions numpy/core/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from . import overrides
from . import umath
from . import shape_base
from .overrides import set_module
from .umath import (multiply, invert, sin, PINF, NAN)
from . import numerictypes
Expand Down Expand Up @@ -545,16 +546,19 @@ def argwhere(a):

Returns
-------
index_array : ndarray
index_array : (N, a.ndim) ndarray
Indices of elements that are non-zero. Indices are grouped by element.
This array will have shape ``(N, a.ndim)`` where ``N`` is the number of
non-zero items.

See Also
--------
where, nonzero

Notes
-----
``np.argwhere(a)`` is the same as ``np.transpose(np.nonzero(a))``.
``np.argwhere(a)`` is almost the same as ``np.transpose(np.nonzero(a))``,
but produces a result of the correct shape for a 0D array.

The output of ``argwhere`` is not suitable for indexing arrays.
For this purpose use ``nonzero(a)`` instead.
Expand All @@ -572,6 +576,11 @@ def argwhere(a):
[1, 2]])

"""
# nonzero does not behave well on 0d, so promote to 1d
if np.ndim(a) == 0:
a = shape_base.atleast_1d(a)
# then remove the added dimension
return argwhere(a)[:,:0]
return transpose(nonzero(a))


Expand Down
15 changes: 8 additions & 7 deletions numpy/core/shape_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@

from . import numeric as _nx
from . import overrides
from .numeric import array, asanyarray, newaxis
from ._asarray import array, asanyarray
from .multiarray import normalize_axis_index
from . import fromnumeric as _from_nx


array_function_dispatch = functools.partial(
Expand Down Expand Up @@ -123,7 +124,7 @@ def atleast_2d(*arys):
if ary.ndim == 0:
result = ary.reshape(1, 1)
elif ary.ndim == 1:
result = ary[newaxis, :]
result = ary[_nx.newaxis, :]
else:
result = ary
res.append(result)
Expand Down Expand Up @@ -193,9 +194,9 @@ def atleast_3d(*arys):
if ary.ndim == 0:
result = ary.reshape(1, 1, 1)
elif ary.ndim == 1:
result = ary[newaxis, :, newaxis]
result = ary[_nx.newaxis, :, _nx.newaxis]
elif ary.ndim == 2:
result = ary[:, :, newaxis]
result = ary[:, :, _nx.newaxis]
else:
result = ary
res.append(result)
Expand Down Expand Up @@ -435,9 +436,9 @@ def stack(arrays, axis=0, out=None):
# Internal functions to eliminate the overhead of repeated dispatch in one of
# the two possible paths inside np.block.
# Use getattr to protect against __array_function__ being disabled.
_size = getattr(_nx.size, '__wrapped__', _nx.size)
_ndim = getattr(_nx.ndim, '__wrapped__', _nx.ndim)
_concatenate = getattr(_nx.concatenate, '__wrapped__', _nx.concatenate)
_size = getattr(_from_nx.size, '__wrapped__', _from_nx.size)
_ndim = getattr(_from_nx.ndim, '__wrapped__', _from_nx.ndim)
_concatenate = getattr(_from_nx.concatenate, '__wrapped__', _from_nx.concatenate)


def _block_format_index(index):
Expand Down
24 changes: 24 additions & 0 deletions numpy/core/tests/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -2583,6 +2583,30 @@ def test_no_overwrite(self):


class TestArgwhere(object):

@pytest.mark.parametrize('nd', [0, 1, 2])
def test_nd(self, nd):
# get an nd array with multiple elements in every dimension
x = np.empty((2,)*nd, bool)

# none
x[...] = False
assert_equal(np.argwhere(x).shape, (0, nd))

# only one
x[...] = False
x.flat[0] = True
assert_equal(np.argwhere(x).shape, (1, nd))

# all but one
x[...] = True
x.flat[0] = False
assert_equal(np.argwhere(x).shape, (x.size - 1, nd))

# all
x[...] = True
assert_equal(np.argwhere(x).shape, (x.size, nd))

def test_2D(self):
x = np.arange(6).reshape((2, 3))
assert_array_equal(np.argwhere(x > 1),
3A2E Expand Down
0