8000 API: Add Array API setops · numpy/numpy@e05149a · GitHub
[go: up one dir, main page]

Skip to content

Commit e05149a

Browse files
committed
API: Add Array API setops
1 parent 29cbb1f commit e05149a

File tree

8 files changed

+282
-9
lines changed

8 files changed

+282
-9
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Array API set functions
2+
-----------------------
3+
4+
`numpy.unique_all`, `numpy.unique_counts`, `numpy.unique_inverse`,
5+
and `numpy.unique_values` functions have been added for Array API compatiblity.
6+
They provide functionality of `numpy.unique` with different sets of flags.

doc/source/reference/routines.set.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ Making proper sets
99
:toctree: generated/
1010

1111
unique
12+
unique_all
13+
unique_counts
14+
unique_inverse
15+
unique_values
1216

1317
Boolean operations
1418
------------------

numpy/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,8 @@
207207
real_if_close, typename, mintypecode, common_type
208208
)
209209
from .lib._arraysetops_impl import (
210-
ediff1d, in1d, intersect1d, isin, setdiff1d, setxor1d, union1d, unique
210+
ediff1d, in1d, intersect1d, isin, setdiff1d, setxor1d, union1d,
211+
unique, unique_all, unique_counts, unique_inverse, unique_values
211212
)
212213
from .lib._ufunclike_impl import fix, isneginf, isposinf
213214
from .lib._arraypad_impl import pad

numpy/__init__.pyi

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -405,13 +405,17 @@ from numpy.lib._arraypad_impl import (
405405

406406
from numpy.lib._arraysetops_impl import (
407407
ediff1d as ediff1d,
408+
in1d as in1d,
408409
intersect1d as intersect1d,
410+
isin as isin,
411+
setdiff1d as setdiff1d,
409412
setxor1d as setxor1d,
410413
union1d as union1d,
411-
setdiff1d as setdiff1d,
412414
unique as unique,
413-
in1d as in1d,
414-
isin as isin,
415+
unique_all as unique_all,
416+
unique_counts as unique_counts,
417+
unique_inverse as unique_inverse,
418+
unique_values as unique_values,
415419
)
416420

417421
from numpy.lib._function_base_impl import (

numpy/lib/_arraysetops_impl.py

Lines changed: 192 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"""
1717
import functools
1818
import warnings
19+
from typing import NamedTuple
1920

2021
import numpy as np
2122
from numpy._core import overrides
@@ -26,9 +27,10 @@
2627

2728

2829
__all__ = [
29-
'ediff1d', 'intersect1d', 'setxor1d', 'union1d', 'setdiff1d', 'unique',
30-
'in1d', 'isin'
31-
]
30+
"ediff1d", "in1d", "intersect1d", "isin", "setdiff1d", "setxor1d",
31+
"union1d", "unique", "unique_all", "unique_counts", "unique_inverse",
32+
"unique_values"
33+
]
3234

3335

3436
def _ediff1d_dispatcher(ary, to_end=None, to_begin=None):
@@ -364,6 +366,193 @@ def _unique1d(ar, return_index=False, return_inverse=False,
364366
return ret
365367

366368

369+
# Array API set functions
370+
371+
class UniqueAllResult(NamedTuple):
372+
values: np.ndarray
373+
indices: np.ndarray
374+
inverse_indices: np.ndarray
375+
counts: np.ndarray
376+
377+
378+
class UniqueCountsResult(NamedTuple):
379+
values: np.ndarray
380+
counts: np.ndarray
381+
382+
383+
class UniqueInverseResult(NamedTuple):
384+
values: np.ndarray
385+
inverse_indices: np.ndarray
386+
387+
388+
UniqueAllResult.__module__ = "numpy"
389+
UniqueCountsResult.__module__ = "numpy"
390+
UniqueInverseResult.__module__ = "numpy"
391+
392+
393+
def _unique_all_dispatcher(x, /):
394+
return (x,)
395+
396+
397+
@array_function_dispatch(_unique_all_dispatcher)
398+
def unique_all(x):
399+
"""
400+
Returns the unique elements of an input array `x`, the first occurring
401+
indices for each unique element in `x`, the indices from the set of unique
402+
elements that reconstruct `x`, and the corresponding counts for each
403+
unique element in `x`.
404+
405+
This function is Array API compatible alternative to:
406+
``np.unique(x, return_index=True, return_inverse=True, return_counts=True,
407+
equal_nan=False)``
408+
409+
Parameters
410+
----------
411+
x : array_like
412+
Input array. It will be flattened if it is not already 1-D.
413+
414+
Returns
415+
-------
416+
out : namedtuple
417+
The result containing:
418+
* values - The unique elements of an input array.
419+
* indices - The first occurring indices for each unique element.
420+
* inverse_indices - The indices from the set of unique elements
421+
that reconstruct `x`.
422+
* counts - The corresponding counts for each unique element.
423+
424+
See Also
425+
--------
426+
unique : Find the unique elements of an array.
427+
428+
"""
429+
result = unique(
430+
x,
431+
return_index=True,
432+
return_inverse=True,
433+
return_counts=True,
434+
equal_nan=False
435+
)
436+
return UniqueAllResult(*result)
437+
438+
439+
def _unique_counts_dispatcher(x, /):
440+
return (x,)
441+
442+
443+
@array_function_dispatch(_unique_counts_dispatcher)
444+
def unique_counts(x):
445+
"""
446+
Returns the unique elements of an input array `x`, and the corresponding
447+
counts for each unique element in `x`.
448+
449+
This function is Array API compatible alternative to:
450+
``np.unique(x, return_counts=True, equal_nan=False)``
451+
452+
Parameters
453+
----------
454+
x : array_like
455+
Input array. It will be flattened if it is not already 1-D.
456+
457+
Returns
458+
-------
459+
out : namedtuple
460+
The result containing:
461+
* values - The unique elements of an input array.
462+
* counts - The corresponding counts for each unique element.
463+
464+
See Also
465+
--------
466+
unique : Find the unique elements of an array.
467+
468+
"""
469+
result = unique(
470+
x,
471+
return_index=False,
472+
return_inverse=False,
473+
return_counts=True,
474+
equal_nan=False
475+
)
476+
return UniqueCountsResult(*result)
477+
478+
479+
def _unique_inverse_dispatcher(x, /):
480+
return (x,)
481+
482+
483+
@array_function_dispatch(_unique_inverse_dispatcher)
484+
def unique_inverse(x):
485+
"""
486+
Returns the unique elements of an input array `x` and the indices
487+
from the set of unique elements that reconstruct `x`.
488+
489+
This function is Array API compatible alternative to:
490+
``np.unique(x, return_inverse=True, equal_nan=False)``
491+
492+
Parameters
493+
----------
494+
x : array_like
495+
Input array. It will be flattened if it is not already 1-D.
496+
497+
Returns
498+
-------
499+
out : namedtuple
500+
The result containing:
501+
* values - The unique elements of an input array.
502+
* inverse_indices - The indices from the set of unique elements
503+
that reconstruct `x`.
504+
505+
See Also
506+
--------
507+
unique : Find the unique elements of an array.
508+
509+
"""
510+
result = unique(
511+
x,
512+
return_index=False,
513+
return_inverse=True,
514+
return_counts=False,
515+
equal_nan=False
516+
)
517+
return UniqueInverseResult(*result)
518+
519+
520+
def _unique_values_dispatcher(x, /):
521+
return (x,)
522+
523+
524+
@array_function_dispatch(_unique_values_dispatcher)
525+
def unique_values(x):
526+
"""
527+
Returns the unique elements of an input array `x`.
528+
529+
This function is Array API compatible alternative to:
530+
``np.unique(x, equal_nan=False)``
531+
532+
Parameters
533+
----------
534+
x : array_like
535+
Input array. It will be flattened if it is not already 1-D.
536+
537+
Returns
538+
-------
539+
out : ndarray
540+
The unique elements of an input array.
541+
542+
See Also
543+
--------
544+
unique : Find the unique elements of an array.
545+
546+
"""
547+
return unique(
548+
x,
549+
return_index=False,
550+
return_inverse=False,
551+
return_counts=False,
552+
equal_nan=False
553+
)
554+
555+
367556
def _intersect1d_dispatcher(
368557
ar1, ar2, assume_unique=None, return_indices=None):
369558
return (ar1, ar2)

numpy/lib/_arraysetops_impl.pyi

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from typing import (
2-
Literal as L,
32
Any,
4-
TypeVar,
3+
Literal as L,
4+
NamedTuple,
55
overload,
66
SupportsIndex,
7+
TypeVar,
78
)
89

910
from numpy import (
@@ -279,6 +280,34 @@ def unique(
279280
equal_nan: bool = ...,
280281
) -> tuple[NDArray[Any], NDArray[intp], NDArray[intp], NDArray[intp]]: ...
281282

283+
@overload
284+
def unique_all(
285+
x: _ArrayLike[_SCT], /
286+
) -> tuple[NDArray[_SCT], NDArray[intp], NDArray[intp], NDArray[intp]]: ...
287+
@overload
288+
def unique_all(
289+
x: ArrayLike, /
290+
) -> tuple[NDArray[Any], NDArray[intp], NDArray[intp], NDArray[intp]]: ...
291+
292+
@overload
293+
def unique_counts(
294+
x: _ArrayLike[_SCT], /
295+
) -> tuple[NDArray[_SCT], NDArray[intp]]: ...
296+
@overload
297+
def unique_counts(
298+
x: ArrayLike, /
299+
) -> tuple[NDArray[Any], NDArray[intp]]: ...
300+
301+
@overload
302+
def unique_inverse(x: _ArrayLike[_SCT], /) -> tuple[NDArray[_SCT], NDArray[intp]]: ...
303+
@overload
304+
def unique_inverse(x: ArrayLike, /) -> tuple[NDArray[Any], NDArray[intp]]: ...
305+
306+
@overload
307+
def unique_values(x: _ArrayLike[_SCT], /) -> NDArray[_SCT]: ...
308+
@overload
309+
def unique_values(x: ArrayLike, /) -> NDArray[Any]: ...
310+
282311
@overload
283312
def intersect1d(
284313
ar1: _ArrayLike[_SCTNoCast],

numpy/lib/tests/test_arraysetops.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -919,3 +919,34 @@ def test_unique_nanequals(self):
919919
not_unq = np.unique(a, equal_nan=False)
920920
assert_array_equal(unq, np.array([1, np.nan]))
921921
assert_array_equal(not_unq, np.array([1, np.nan, np.nan, np.nan]))
922+
923+
def test_unique_array_api_functions(self):
924+
arr = np.array([np.nan, 1, 4, 1, 3, 4, np.nan, 5, 1])
925+
926+
for res_unique_array_api, res_unique in [
927+
(
928+
np.unique_values(arr),
929+
np.unique(arr, equal_nan=False)
930+
),
931+
(
932+
np.unique_counts(arr),
933+
np.unique(arr, return_counts=True, equal_nan=False)
934+
),
935+
(
936+
np.unique_inverse(arr),
937+
np.unique(arr, return_inverse=True, equal_nan=False)
938+
),
939+
(
940+
np.unique_all(arr),
941+
np.unique(
942+
arr,
943+
return_index=True,
944+
return_inverse=True,
945+
return_counts=True,
946+
equal_nan=False
947+
)
948+
)
949+
]:
950+
assert len( B339 res_unique_array_api) == len(res_unique)
951+
for actual, expected in zip(res_unique_array_api, res_unique):
952+
assert_array_equal(actual, expected)

numpy/typing/tests/data/reveal/arraysetops.pyi

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,12 @@ assert_type(np.unique(AR_f8, return_inverse=True, return_counts=True), tuple[npt
6161
assert_type(np.unique(AR_LIKE_f8, return_inverse=True, return_counts=True), tuple[npt.NDArray[Any], npt.NDArray[np.intp], npt.NDArray[np.intp]])
6262
assert_type(np.unique(AR_f8, return_index=True, return_inverse=True, return_counts=True), tuple[npt.NDArray[np.float64], npt.NDArray[np.intp], npt.NDArray[np.intp], npt.NDArray[np.intp]])
6363
assert_type(np.unique(AR_LIKE_f8, return_index=True, return_inverse=True, return_counts=True), tuple[npt.NDArray[Any], npt.NDArray[np.intp], npt.NDArray[np.intp], npt.NDArray[np.intp]])
64+
65+
assert_type(np.unique_all(AR_f8), tuple[npt.NDArray[np.float64], npt.NDArray[np.intp], npt.NDArray[np.intp], npt.NDArray[np.intp]])
66+
assert_type(np.unique_all(AR_LIKE_f8), tuple[npt.NDArray[Any], npt.NDArray[np.intp], npt.NDArray[np.intp], npt.NDArray[np.intp]])
67+
assert_type(np.unique_counts(AR_f8), tuple[npt.NDArray[np.float64], npt.NDArray[np.intp]])
68+
assert_type(np.unique_counts(AR_LIKE_f8), tuple[npt.NDArray[Any], npt.NDArray[np.intp]])
69+
assert_type(np.unique_inverse(AR_f8), tuple[npt.NDArray[np.float64], npt.NDArray[np.intp]])
70+
assert_type(np.unique_inverse(AR_LIKE_f8), tuple[npt.NDArray[Any], npt.NDArray[np.intp]])
71+
assert_type(np.unique_values(AR_f8), npt.NDArray[np.float64])
72+
assert_type(np.unique_values(AR_LIKE_f8), npt.NDArray[Any])

0 commit comments

Comments
 (0)
0