8000 API: Add Array API setops · numpy/numpy@04233a6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 04233a6

Browse files
committed
API: Add Array API setops
1 parent 29cbb1f commit 04233a6

File tree

7 files changed

+165
-9
lines changed

7 files changed

+165
-9
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Array API set functions
2+
-----------------------
3+
4+
`numpy.unique_all`, `numpy.unique_counts`, `numpy.unique_inverse`,
5+
and `numpy.unique_values` functions have been added for Array API compatiblity.
6+
They provide functionality of `numpy.unique` with different sets of flags.

doc/source/reference/routines.set.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ Making proper sets
99
:toctree: generated/
1010

1111
unique
12+
unique_all
13+
unique_counts
14+
unique_inverse
15+
unique_values
1216

1317
Boolean operations
1418
------------------

numpy/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,8 @@
207207
real_if_close, typename, mintypecode, common_type
208208
)
209209
from .lib._arraysetops_impl import (
210-
ediff1d, in1d, intersect1d, isin, setdiff1d, setxor1d, union1d, unique
210+
ediff1d, in1d, intersect1d, isin, setdiff1d, setxor1d, union1d,
211+
unique, unique_all, unique_counts, unique_inverse, unique_values
211212
)
212213
from .lib._ufunclike_impl import fix, isneginf, isposinf
213214
from .lib._arraypad_impl import pad

numpy/__init__.pyi

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -405,13 +405,17 @@ from numpy.lib._arraypad_impl import (
405405

406406
from numpy.lib._arraysetops_impl import (
407407
ediff1d as ediff1d,
408+
in1d as in1d,
408409
intersect1d as intersect1d,
410+
isin as isin,
411+
setdiff1d as setdiff1d,
409412
setxor1d as setxor1d,
410413
union1d as union1d,
411-
setdiff1d as setdiff1d,
412414
unique as unique,
413-
in1d as in1d,
414-
isin as isin,
415+
unique_all as unique_all,
416+
unique_counts as unique_counts,
417+
unique_inverse as unique_inverse,
418+
unique_values as unique_values,
415419
)
416420

417421
from numpy.lib._function_base_impl import (

numpy/lib/_arraysetops_impl.py

Lines changed: 100 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"""
6D40 1717
import functools
1818
import warnings
19+
from typing import NamedTuple
1920

2021
import numpy as np
2122
from numpy._core import overrides
@@ -26,9 +27,10 @@
2627

2728

2829
__all__ = [
29-
'ediff1d', 'intersect1d', 'setxor1d', 'union1d', 'setdiff1d', 'unique',
30-
'in1d', 'isin'
31-
]
30+
"ediff1d", "in1d", "intersect1d", "isin", "setdiff1d", "setxor1d",
31+
"union1d", "unique", "unique_all", "unique_counts", "unique_inverse",
32+
"unique_values"
33+
]
3234

3335

3436
def _ediff1d_dispatcher(ary, to_end=None, to_begin=None):
@@ -364,6 +366,101 @@ def _unique1d(ar, return_index=False, return_inverse=False,
364366
return ret
365367

366368

369+
# Array API set functions
370+
371+
class UniqueAllResult(NamedTuple):
372+
values: np.ndarray
373+
indices: np.ndarray
374+
inverse_indices: np.ndarray
375+
counts: np.ndarray
376+
377+
378+
class UniqueCountsResult(NamedTuple):
379+
values: np.ndarray
380+
counts: np.ndarray
381+
382+
383+
class UniqueInverseResult(NamedTuple):
384+
values: np.ndarray
385+
inverse_indices: np.ndarray
386+
387+
388+
UniqueAllResult.__module__ = "numpy"
389+
UniqueCountsResult.__module__ = "numpy"
390+
UniqueInverseResult.__module__ = "numpy"
391+
392+
393+
def _unique_all_dispatcher(x, /):
394+
return (x,)
395+
396+
397+
@array_function_dispatch(_unique_all_dispatcher)
398+
def unique_all(x):
399+
"""
400+
"""
401+
result = unique(
402+
x,
403+
return_index=True,
404+
return_inverse=True,
405+
return_counts=True,
406+
equal_nan=False
407+
)
408+
return UniqueAllResult(*result)
409+
410+
411+
def _unique_counts_dispatcher(x, /):
412+
return (x,)
413+
414+
415+
@array_function_dispatch(_unique_counts_dispatcher)
416+
def unique_counts(x):
417+
"""
418+
"""
419+
result = unique(
420+
x,
421+
return_index=False,
422+
return_inverse=False,
423+
return_counts=True,
424+
equal_nan=False
425+
)
426+
return UniqueCountsResult(*result)
427+
428+
429+
def _unique_inverse_dispatcher(x, /):
430+
return (x,)
431+
432+
433+
@array_function_dispatch(_unique_inverse_dispatcher)
434+
def unique_inverse(x):
435+
"""
436+
"""
437+
result = unique(
438+
x,
439+
return_index=False,
440+
return_inverse=True,
441+
return_counts=False,
442+
equal_nan=False
443+
)
444+
return UniqueInverseResult(*result)
445+
446+
447+
def _unique_values_dispatcher(x, /):
448+
return (x,)
449+
450+
451+
@array_function_dispatch(_unique_values_dispatcher)
452+
def unique_values(x):
453+
"""
454+
"""
455+
return unique(
456+
x,
457+
return_index=False,
458+
return_inverse=False,
459+
return_counts=False,
460+
equal_nan=False
461+
)
462+
463+
367464
def _intersect1d_dispatcher(
368465
ar1, ar2, assume_unique=None, return_indices=None):
369466
return (ar1, ar2)

numpy/lib/_arraysetops_impl.pyi

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from typing import (
2-
Literal as L,
32
Any,
4-
TypeVar,
3+
Literal as L,
4+
NamedTuple,
55
overload,
66
SupportsIndex,
7+
TypeVar,
78
)
89

910
from numpy import (
@@ -279,6 +280,18 @@ def unique(
279280
equal_nan: bool = ...,
280281
) -> tuple[NDArray[Any], NDArray[intp], NDArray[intp], NDArray[intp]]: ...
281282

283+
def unique_all(
284+
x: ArrayLike, /
285+
) -> tuple[NDArray[Any], NDArray[intp], NDArray[intp], NDArray[intp]]: ...
286+
287+
def unique_counts(
288+
x: ArrayLike, /
289+
) -> tuple[NDArray[Any], NDArray[intp]]: ...
290+
291+
def unique_inverse(x: ArrayLike, /) -> tuple[NDArray[Any], NDArray[intp]]: ...
292+
293+
def unique_values(x: ArrayLike, /) -> NDArray[Any]: ...
294+
282295
@overload
283296
def intersect1d(
284297
ar1: _ArrayLike[_SCTNoCast],

numpy/lib/tests/test_arraysetops.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -919,3 +919,34 @@ def test_unique_nanequals(self):
919919
not_unq = np.unique(a, equal_nan=False)
920920
assert_array_equal(unq, np.array([1, np.nan]))
921921
assert_array_equal(not_unq, np.array([1, np.nan, np.nan, np.nan]))
922+
923+
def test_unique_array_api_functions(self):
924+
arr = np.array([np.nan, 1, 4, 1, 3, 4, np.nan, 5, 1])
925+
926+
for res_unique_array_api, res_unique in [
927+
(
928+
np.unique_values(arr),
929+
np.unique(arr, equal_nan=False)
930+
),
931+
(
932+
np.unique_counts(arr),
933+
np.unique(arr, return_counts=True, equal_nan=False)
934+
),
935+
(
936+
np.unique_inverse(arr),
937+
np.unique(arr, return_inverse=True, equal_nan=False)
938+
),
939+
(
940+
np.unique_all(arr),
941+
np.unique(
942+
arr,
943+
return_index=True,
944+
return_inverse=True,
945+
return_counts=True,
946+
equal_nan=False
947+
)
948+
)
949+
]:
950+
assert len(res_unique_array_api) == len(res_unique)
951+
for actual, expected in zip(res_unique_array_api, res_unique):
952+
assert_array_equal(actual, expected)

0 commit comments

Comments
 (0)
0