8000 API: Add Array API setops · numpy/numpy@f7f9509 · GitHub
[go: up one dir, main page]

Skip to content

Commit f7f9509

Browse files
committed
API: Add Array API setops
1 parent 06d7bdf commit f7f9509

File tree

10 files changed

+295
-21
lines changed

10 files changed

+295
-21
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Array API set functions
2+
-----------------------
3+
4+
`numpy.unique_all`, `numpy.unique_counts`, `num 8000 py.unique_inverse`,
5+
and `numpy.unique_values` functions have been added for Array API compatiblity.
6+
They provide functionality of `numpy.unique` with different sets of flags.

doc/source/reference/array_api.rst

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,6 @@ The following functions are named differently in the array API
119119
* - ``pow``
120120
- ``power``
121121
-
122-
* - ``unique_all``, ``unique_counts``, ``unique_inverse``, and
123-
``unique_values``
124-
- ``unique``
125-
- Each is equivalent to ``np.unique`` with certain flags set.
126122

127123

128124
Function instead of method

doc/source/reference/routines.set.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ Making proper sets
99
:toctree: generated/
1010

1111
unique
12+
unique_all
13+
unique_counts
14+
unique_inverse
15+
unique_values
1216

1317
Boolean operations
1418
------------------

numpy/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,8 @@
208208
real_if_close, typename, mintypecode, common_type
209209
)
210210
from .lib._arraysetops_impl import (
211-
ediff1d, in1d, intersect1d, isin, setdiff1d, setxor1d, union1d, unique
211+
ediff1d, in1d, intersect1d, isin, setdiff1d, setxor1d, union1d,
212+
unique, unique_all, unique_counts, unique_inverse, unique_values
212213
)
213214
from .lib._ufunclike_impl import fix, isneginf, isposinf
214215
from .lib._arraypad_impl import pad

numpy/__init__.pyi

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -406,13 +406,17 @@ from numpy.lib._arraypad_impl import (
406406

407407
from numpy.lib._arraysetops_impl import (
408408
ediff1d as ediff1d,
409+
in1d as in1d,
409410
intersect1d as intersect1d,
411+
isin as isin,
412+
setdiff1d as setdiff1d,
410413
setxor1d as setxor1d,
411414
union1d as union1d,
412-
setdiff1d as setdiff1d,
413415
unique as unique,
414-
in1d as in1d,
415-
isin as isin,
416+
unique_all as unique_all,
417+
unique_counts as unique_counts,
418+
unique_inverse as unique_inverse,
419+
unique_values as unique_values,
416420
)
417421

418422
from numpy.lib._function_base_impl import (

numpy/lib/_arraysetops_impl.py

Lines changed: 187 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"""
1717
import functools
1818
import warnings
19+
from typing import NamedTuple
1920

2021
import numpy as np
2122
from numpy._core import overrides
@@ -26,9 +27,10 @@
2627

2728

2829
__all__ = [
29-
'ediff1d', 'intersect1d', 'setxor1d', 'union1d', 'setdiff1d', 'unique',
30-
'in1d', 'isin'
31-
]
30+
"ediff1d", "in1d", "intersect1d", "isin", "setdiff1d", "setxor1d",
31+
"union1d", "unique", "unique_all", "unique_counts", "unique_inverse",
32+
"unique_values"
33+
]
3234

3335

3436
def _ediff1d_dispatcher(ary, to_end=None, to_begin=None):
@@ -364,6 +366,188 @@ def _unique1d(ar, return_index=False, return_inverse=False,
364366
return ret
365367

366368

369+
# Array API set functions
370+
371+
class UniqueAllResult(NamedTuple):
372+
values: np.ndarray
373+
indices: np.ndarray
374+
inverse_indices: np.ndarray
375+
counts: np.ndarray
376+
377+
378+
class UniqueCountsResult(NamedTuple):
379+
values: np.ndarray
380+
counts: np.ndarray
381+
382+
383+
class UniqueInverseResult(NamedTuple):
384+
values: np.ndarray
385+
inverse_indices: np.ndarray
386+
387+
388+
def _unique_all_dispatcher(x, /):
389+
return (x,)
390+
391+
392+
@array_function_dispatch(_unique_all_dispatcher)
393+
def unique_all(x):
394+
"""
395+
Returns the unique elements of an input array `x`, the first occurring
396+
indices for each unique element in `x`, the indices from the set of unique
397+
elements that reconstruct `x`, and the corresponding counts for each
398+
unique element in `x`.
399+
400+
This function is Array API compatible alternative to:
401+
``np.unique(x, return_index=True, return_inverse=True, return_counts=True,
402+
equal_nan=False)``
403+
404+
Parameters
405+
----------
406+
x : array_like
407+
Input array. It will be flattened if it is not already 1-D.
408+
409+
Returns
410+
-------
411+
out : namedtuple
412+
The result containing:
413+
* values - The unique elements of an input array.
414+
* indices - The first occurring indices for each unique element.
415+
* inverse_indices - The indices from the set of unique elements
416+
that reconstruct `x`.
417+
* counts - The corresponding counts for each unique element.
418+
419+
See Also
420+
--------
421+
unique : Find the unique elements of an array.
422+
423+
"""
424+
result = unique(
425+
x,
426+
return_index=True,
427+
return_inverse=True,
428+
return_counts=True,
429+
equal_nan=False
430+
)
431+
return UniqueAllResult(*result)
432+
433+
434+
def _unique_counts_dispatcher(x, /):
435+
return (x,)
436+
437+
438+
@array_function_dispatch(_unique_counts_dispatcher)
439+
def unique_counts(x):
440+
"""
441+
Returns the unique elements of an input array `x`, and the corresponding
442+
counts for each unique element in `x`.
443+
444+
This function is Array API compatible alternative to:
445+
``np.unique(x, return_counts=True, equal_nan=False)``
446+
447+
Parameters
448+
----------
449+
x : array_like
450+
Input array. It will be flattened if it is not already 1-D.
451+
452+
Returns
453+
-------
454+
out : namedtuple
455+
The result containing:
456+
* values - The unique elements of an input array.
457+
* counts - The corresponding counts for each unique element.
458+
459+
See Also
460+
--------
461+
unique : Find the unique elements of an array.
462+
463+
"""
464+
result = unique(
465+
x,
466+
return_index=False,
467+
return_inverse=False,
468+
return_counts=True,
469+
equal_nan=False
470+
)
471+
return UniqueCountsResult(*result)
472+
473+
474+
def _unique_inverse_dispatcher(x, /):
475+
return (x,)
476+
477+
478+
@array_function_dispatch(_unique_inverse_dispatcher)
479+
def unique_inverse(x):
480+
"""
481+
Returns the unique elements of an input array `x` and the indices
482+
from the set of unique elements that reconstruct `x`.
483+
484+
This function is Array API compatible alternative to:
485+
``np.unique(x, return_inverse=True, equal_nan=False)``
486+
487+
Parameters
488+
----------
489+
x : array_like
490+
Input array. It will be flattened if it is not already 1-D.
491+
492+
Returns
493+
-------
494+
out : namedtuple
495+
The result containing:
496+
* values - The unique elements of an input array.
497+
* inverse_indices - The indices from the set of unique elements
498+
that reconstruct `x`.
499+
500+
See Also
501+
--------
502+
unique : Find the unique elements of an array.
503+
504+
"""
505+
result = unique(
506+
x,
507+
return_index=False,
508+
return_inverse=True,
509+
return_counts=False,
510+
equal_nan=False
511+
)
512+
return UniqueInverseResult(*result)
513+
514+
515+
def _unique_values_dispatcher(x, /):
516+
return (x,)
517+
518+
519+
@array_function_dispatch(_unique_values_dispatcher)
520+
def unique_values(x):
521+
"""
522+
Returns the unique elements of an input array `x`.
523+
524+
This function is Array API compatible alternative to:
525+
``np.unique(x, equal_nan=False)``
526+
527+
Parameters
528+
----------
529+
x : array_like
530+
Input array. It will be flattened if it is not already 1-D.
531+
532+
Returns
533+
-------
534+
out : ndarray
535+
The unique elements of an input array.
536+
537+
See Also
538+
--------
539+
unique : Find the unique elements of an array.
540+
541+
"""
542+
return unique(
543+
x,
544+
return_index=False,
545+
return_inverse=False,
546+
return_counts=False,
547+
equal_nan=False
548+
)
549+
550+
367551
def _intersect1d_dispatcher(
368552
ar1, ar2, assume_unique=None, return_indices=None):
369553
return (ar1, ar2)

numpy/lib/_arraysetops_impl.pyi

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from typing import (
2-
Literal as L,
32
Any,
4-
TypeVar,
3+
Generic,
4+
Literal as L,
5+
NamedTuple,
56
overload,
67
SupportsIndex,
8+
TypeVar,
79
)
810

911
import numpy as np
@@ -85,6 +87,20 @@ _SCTNoCast = TypeVar(
8587
void,
8688
)
8789

90+
class UniqueAllResult(NamedTuple, Generic[_SCT]):
91+
values: NDArray[_SCT]
92+
indices: NDArray[intp]
93+
inverse_indices: NDArray[intp]
94+
counts: NDArray[intp]
95+
96+
class UniqueCountsResult(NamedTuple, Generic[_SCT]):
97+
values: NDArray[_SCT]
98+
counts: NDArray[intp]
99+
100+
class UniqueInverseResult(NamedTuple, Generic[_SCT]):
101+
values: NDArray[_SCT]
102+
inverse_indices: NDArray[intp]
103+
88104
__all__: list[str]
89105

90106
@overload
@@ -279,6 +295,34 @@ def unique(
279295
equal_nan: bool = ...,
280296
) -> tuple[NDArray[Any], NDArray[intp], NDArray[intp], NDArray[intp]]: ...
281297

298+
@overload
299+
def unique_all(
300+
x: _ArrayLike[_SCT], /
301+
) -> UniqueAllResult[_SCT]: ...
302+
@overload
303+
def unique_all(
304+
x: ArrayLike, /
305+
) -> UniqueAllResult[Any]: ...
306+
307+
@overload
308+
def unique_counts(
309+
x: _ArrayLike[_SCT], /
310+
) -> UniqueCountsResult[_SCT]: ...
311+
@overload
312+
def unique_counts(
313+
x: ArrayLike, /
314+
) -> UniqueCountsResult[Any]: ...
315+
316+
@overload
317+
def unique_inverse(x: _ArrayLike[_SCT], /) -> UniqueInverseResult[_SCT]: ...
318+
@overload
319+
def unique_inverse(x: ArrayLike, /) -> UniqueInverseResult[Any]: ...
320+
321+
@overload
322+
def unique_values(x: _ArrayLike[_SCT], /) -> NDArray[_SCT]: ...
323+
@overload
324+
def unique_values(x: ArrayLike, /) -> NDArray[Any]: ...
325+
282326
@overload
283327
def intersect1d(
284328
ar1: _ArrayLike[_SCTNoCast],

numpy/lib/tests/test_arraysetops.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -919,3 +919,34 @@ def test_unique_nanequals(self):
919919
not_unq = np.unique(a, equal_nan=False)
920920
assert_array_equal(unq, np.array([1, np.nan]))
921921
assert_array_equal(not_unq, np.array([1, np.nan, np.nan, np.nan]))
922+
923+
def test_unique_array_api_functions(self):
924+
arr = np.array([np.nan, 1, 4, 1, 3, 4, np.nan, 5, 1])
925+
926+
for res_unique_array_api, res_unique in [
927+
(
928+
np.unique_values(arr),
929+
np.unique(arr, equal_nan=False)
930+
),
931+
(
932+
np.unique_counts(arr),
933+
np.unique(arr, return_counts=True, equal_nan=False)
934+
),
935+
(
936+
np.unique_inverse(arr),
937+
np.unique(arr, return_inverse=True, equal_nan=False)
938+
),
939+
(
940+
np.unique_all(arr),
941+
np.unique(
942+
arr,
943+
return_index=True,
944+
return_inverse=True,
945+
return_counts=True,
946+
equal_nan=False
947+
)
948+
)
949+
]:
950+
assert len(res_unique_array_api) == len(res_unique)
951+
for actual, expected in zip(res_unique_array_api, res_unique):
952+
assert_array_equal(actual, expected)

numpy/typing/tests/data/reveal/arraysetops.pyi

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ from typing import Any
33

44
import numpy as np
55
import numpy.typing as npt
6+
from numpy.lib._arraysetops_impl import (
7+
UniqueAllResult, UniqueCountsResult, UniqueInverseResult
8+
)
69

710
if sys.version_info >= (3, 11):
811
from typing import assert_type
@@ -61,3 +64,12 @@ assert_type(np.unique(AR_f8, return_inverse=True, return_counts=True), tuple[npt
6164
assert_type(np.unique(AR_LIKE_f8, return_inverse=True, return_counts=True), tuple[npt.NDArray[Any], npt.NDArray[np.intp], npt.NDArray[np.intp]])
6265
assert_type(np.unique(AR_f8, return_index=True, return_inverse=True, return_counts=True), tuple[npt.NDArray[np.float64], npt.NDArray[np.intp], npt.NDArray[np.intp], npt.NDArray[np.intp]])
6366
assert_type(np.unique(AR_LIKE_f8, return_index=True, return_inverse=True, return_counts=True), tuple[npt.NDArray[Any], npt.NDArray[np.intp], npt.NDArray[np.intp], npt.NDArray[np.intp]])
67+
68+
assert_type(np.unique_all(AR_f8), UniqueAllResult[np.float64])
69+
assert_type(np.unique_all(AR_LIKE_f8), UniqueAllResult[Any])
70+
assert_type(np.unique_counts(AR_f8), UniqueCountsResult[np.float64])
71+
assert_type(np.unique_counts(AR_LIKE_f8), UniqueCountsResult[Any])
72+
assert_type(np.unique_inverse(AR_f8), UniqueInverseResult[np.float64])
73+
assert_type(np.unique_inverse(AR_LIKE_f8), UniqueInverseResult[Any])
74+
assert_type(np.unique_values(AR_f8), npt.NDArray[np.float64])
75+
assert_type(np.unique_values(AR_LIKE_f8), npt.NDArray[Any])

0 commit comments

Comments
 (0)
0