10000 BUG: reshape inverse_values output for multi-dimensional np.unique · numpy/numpy@6903f6c · GitHub
[go: up one dir, main page]

Skip to content

Commit 6903f6c

Browse files 10000
committed
BUG: reshape inverse_values output for multi-dimensional np.unique
1 parent 01be917 commit 6903f6c

File tree

3 files changed

+41
-18
lines changed

3 files changed

+41
-18
lines changed

doc/source/release/2.0.0-notes.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,3 +313,12 @@ precalculated mean as an keyword argument. See the doc-strings for details and a
313313
example illustrating the speed-up.
314314

315315
(`gh-24126 <https://github.com/numpy/numpy/pull/24126>`__)
316+
317+
``np.unique`` ``return_inverse`` shape for multi-dimensional inputs
318+
-------------------------------------------------------------------
319+
When multi-dimensional inputs are passed to ``np.unique`` with ``return_inverse=True``,
320+
the ``unique_inverse`` output is now shaped such that the input can be reconstructed
321+
directly using ``np.take(unique, unique_inverse)`` when ``axis = None``, and
322+
``np.take_along_axis(unique, unique_inverse, axis=axis)`` otherwise.
323+
324+
(`gh-25553 <https://github.com/numpy/numpy/pull/24126>`__)

numpy/lib/_arraysetops_impl.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,12 @@ def unique(ar, return_index=False, return_inverse=False,
222222
lexicographical order is chosen - see np.sort for how the lexicographical
223223
order is defined for complex arrays.
224224
225+
.. versionchanged: NumPy 2.0
226+
For multi-dimensional inputs, ``unique_inverse`` is reshaped such that
227+
the input can be reconstructed using ``np.take(unique, unique_inverse)``
228+
when ``axis = None``, and
229+
``np.take_along_axis(unique, unique_inverse, axis=axis)`` otherwise.
230+
225231
Examples
226232
--------
227233
>>> np.unique([1, 1, 2, 2, 3, 3])
@@ -273,7 +279,7 @@ def unique(ar, return_index=False, return_inverse=False,
273279
ar = np.asanyarray(ar)
274280
if axis is None:
275281
ret = _unique1d(ar, return_index, return_inverse, return_counts,
276-
equal_nan=equal_nan)
282+
equal_nan=equal_nan, inverse_shape=ar.shape)
277283
return _unpack_tuple(ret)
278284

279285
# axis was specified and not None
@@ -282,6 +288,8 @@ def unique(ar, return_index=False, return_inverse=False,
282288
except np.exceptions.AxisError:
283289
# this removes the "axis1" or "axis2" prefix from the error message
284290
raise np.exceptions.AxisError(axis, ar.ndim) from None
291+
inverse_shape = [1] * ar.ndim
292+
inverse_shape[axis] = ar.shape[0]
285293

286294
# Must reshape to a contiguous 2D array for this to work...
287295
orig_shape, orig_dtype = ar.shape, ar.dtype
@@ -316,13 +324,14 @@ def reshape_uniq(uniq):
316324
return uniq
317325

318326
output = _unique1d(consolidated, return_index,
319-
return_inverse, return_counts, equal_nan=equal_nan)
327+
return_inverse, return_counts,
328+
equal_nan=equal_nan, inverse_shape=inverse_shape)
320329
output = (reshape_uniq(output[0]),) + output[1:]
321330
return _unpack_tuple(output)
322331

323332

324333
def _unique1d(ar, return_index=False, return_inverse=False,
325-
return_counts=False, *, equal_nan=True):
334+
return_counts=False, *, equal_nan=True, inverse_shape=None):
326335
"""
327336
Find the unique elements of an array, ignoring shape.
328337
"""
@@ -359,7 +368,7 @@ def _unique1d(ar, return_index=False, return_inverse=False,
359368
imask = np.cumsum(mask) - 1
360369
inv_idx = np.empty(mask.shape, dtype=np.intp)
361370
inv_idx[perm] = imask
362-
ret += (inv_idx,)
371+
ret += (inv_idx.reshape(inverse_shape),)
363372
if return_counts:
364373
idx = np.concatenate(np.nonzero(mask) + ([mask.size],))
365374
ret += (np.diff(idx),)
@@ -422,17 +431,14 @@ def unique_all(x):
422431
unique : Find the unique elements of an array.
423432
424433
"""
425-
x = np.asanyarray(x)
426-
values, indices, inverse_indices, counts = unique(
434+
result = unique(
427435
x,
428436
return_index=True,
429437
return_inverse=True,
430438
return_counts=True,
431439
equal_nan=False
432440
)
433-
inverse_indices = inverse_indices.reshape(x.shape)
434-
return UniqueAllResult(values=values, indices=indices,
435-
inverse_indices=inverse_indices, counts=counts)
441+
return UniqueAllResult(*result)
436442

437443

438444
def _unique_counts_dispatcher(x, /):
@@ -512,16 +518,14 @@ def unique_inverse(x):
512518
unique : Find the unique elements of an array.
513519
514520
"""
515-
x = np.asanyarray(x)
516-
values, inverse_indices = unique(
521+
result = unique(
517522
x,
518523
return_index=False,
519524
return_inverse=True,
520525
return_counts=False,
521526
equal_nan=False
522527
)
523-
inverse_indices = inverse_indices.reshape(x.shape)
524-
return UniqueInverseResult(values=values, inverse_indices=inverse_indices)
528+
return UniqueInverseResult(*result)
525529

526530

527531
def _unique_values_dispatcher(x, /):

numpy/lib/tests/test_arraysetops.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,7 @@ def check_all(a, b, i1, i2, c, dt):
686686
# test for ticket #4785
687687
a = [(1, 2), (1, 2), (2, 3)]
688688
unq = [1, 2, 3]
689-
inv = [0, 1, 0, 1, 1, 2]
689+
inv = [[0, 1], [0, 1], [1, 2]]
690690
a1 = unique(a)
691691
assert_array_equal(a1, unq)
692692
a2, a2_inv = unique(a, return_inverse=True)
@@ -810,6 +810,16 @@ def test_unique_1d_with_axis(self, axis):
810810
uniq = unique(x, axis=axis)
811811
assert_array_equal(uniq, [1, 2, 3, 4])
812812

813+
@pytest.mark.parametrize("axis", [None, 0, -1])
814+
def test_unique_inverse_with_axis(self, axis):
815+
x = np.array([[4, 4, 3], [2, 2, 1], [2, 2, 1], [4, 4, 3]])
816+
uniq, inv = unique(x, return_inverse=True, axis=axis)
817+
assert_equal(inv.ndim, x.ndim)
818+
if axis is None:
819+
assert_array_equal(x, np.take(uniq, inv))
820+
else:
821+
assert_array_equal(x, np.take_along_axis(uniq, inv, axis=axis))
822+
813823
def test_unique_axis_zeros(self):
814824
# issue 15559
815825
single_zero = np.empty(shape=(2, 0), dtype=np.int8)
@@ -820,7 +830,7 @@ def test_unique_axis_zeros(self):
820830
assert_equal(uniq.dtype, single_zero.dtype)
821831
assert_array_equal(uniq, np.empty(shape=(1, 0)))
822832
assert_array_equal(idx, np.array([0]))
823-
assert_array_equal(inv, np.array([0, 0]))
833+
assert_array_equal(inv, np.array([[0], [0]]))
824834
assert_array_equal(cnt, np.array([2]))
825835

826836
# there's 0 elements of shape (2,) along axis 1
@@ -830,7 +840,7 @@ def test_unique_axis_zeros(self):
830840
assert_equal(uniq.dtype, single_zero.dtype)
831841
assert_array_equal(uniq, np.empty(shape=(2, 0)))
832842
assert_array_equal(idx, np.array([]))
833-
assert_array_equal(inv, np.array([]))
843+
assert_array_equal(inv, np.empty((1, 0)))
834844
assert_array_equal(cnt, np.array([]))
835845

836846
# test a "complicated" shape
@@ -899,7 +909,7 @@ def _run_axis_tests(self, dtype):
899909
msg = "Unique's return_index=True failed with axis=0"
900910
assert_array_equal(data[idx], uniq, msg)
901911
msg = "Unique's return_inverse=True failed with axis=0"
902-
assert_array_equal(uniq[inv], data)
912+
assert_array_equal(np.take_along_axis(uniq, inv, axis=0), data)
903913
msg = "Unique's return_counts=True failed with axis=0"
904914
assert_array_equal(cnt, np.array([2, 2]), msg)
905915

@@ -908,7 +918,7 @@ def _run_axis_tests(self, dtype):
908918
msg = "Unique's return_index=True failed with axis=1"
909919
assert_array_equal(data[:, idx], uniq)
910920
msg = "Unique's return_inverse=True failed with axis=1"
911-
assert_array_equal(uniq[:, inv], data)
921+
assert_array_equal(np.take_along_axis(uniq, inv, axis=1), data)
912922
msg = "Unique's return_counts=True failed with axis=1"
913923
assert_array_equal(cnt, np.array([2, 1, 1]), msg)
914924

0 commit comments

Comments
 (0)
0