8000 API: Add Array API setops [Array API] by mtsokol · Pull Request #25088 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

API: Add Array API setops [Array API] #25088

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/release/upcoming_changes/25088.new_feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Array API set functions
-----------------------

`numpy.unique_all`, `numpy.unique_counts`, `numpy.unique_inverse`,
and `numpy.unique_values` functions have been added for Array API compatibility.
They provide functionality of `numpy.unique` with different sets of flags.
4 changes: 0 additions & 4 deletions doc/source/reference/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,6 @@ The following functions are named differently in the array API
* - ``pow``
- ``power``
-
* - ``unique_all``, ``unique_counts``, ``unique_inverse``, and
``unique_values``
- ``unique``
- Each is equivalent to ``np.unique`` with certain flags set.


Function instead of method
Expand Down
4 changes: 4 additions & 0 deletions doc/source/reference/routines.set.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ Making proper sets
:toctree: generated/

unique
unique_all
unique_counts
unique_inverse
unique_values

Boolean operations
------------------
Expand Down
3 changes: 2 additions & 1 deletion numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@
real_if_close, typename, mintypecode, common_type
)
from .lib._arraysetops_impl import (
ediff1d, in1d, intersect1d, isin, setdiff1d, setxor1d, union1d, unique
ediff1d, in1d, intersect1d, isin, setdiff1d, setxor1d, union1d,
unique, unique_all, unique_counts, unique_inverse, unique_values
)
from .lib._ufunclike_impl import fix, isneginf, isposinf
from .lib._arraypad_impl import pad
Expand Down
10 changes: 7 additions & 3 deletions numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -406,13 +406,17 @@ from numpy.lib._arraypad_impl import (

from numpy.lib._arraysetops_impl import (
ediff1d as ediff1d,
in1d as in1d,
intersect1d as intersect1d,
isin as isin,
setdiff1d as setdiff1d,
setxor1d as setxor1d,
union1d as union1d,
setdiff1d as setdiff1d,
unique as unique,
in1d as in1d,
isin as isin,
unique_all as unique_all,
unique_counts as unique_counts,
unique_inverse as unique_inverse,
unique_values as unique_values,
)

from numpy.lib._function_base_impl import (
Expand Down
190 changes: 187 additions & 3 deletions numpy/lib/_arraysetops_impl.py
9E7A
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""
import functools
import warnings
from typing import NamedTuple

import numpy as np
from numpy._core import overrides
Expand All @@ -26,9 +27,10 @@


__all__ = [
'ediff1d', 'intersect1d', 'setxor1d', 'union1d', 'setdiff1d', 'unique',
'in1d', 'isin'
]
"ediff1d", "in1d", "intersect1d", "isin", "setdiff1d", "setxor1d",
"union1d", "unique", "unique_all", "unique_counts", "unique_inverse",
"unique_values"
]


def _ediff1d_dispatcher(ary, to_end=None, to_begin=None):
Expand Down Expand Up @@ -364,6 +366,188 @@ def _unique1d(ar, return_index=False, return_inverse=False,
return ret


# Array API set functions

class UniqueAllResult(NamedTuple):
    # Named 4-tuple returned by `unique_all`.
    # Unique elements of the input array.
    values: np.ndarray
    # First occurring index of each unique element in the input.
    indices: np.ndarray
    # Indices into the unique elements that reconstruct the input.
    inverse_indices: np.ndarray
    # Number of occurrences of each unique element in the input.
    counts: np.ndarray


class UniqueCountsResult(NamedTuple):
    # Named 2-tuple returned by `unique_counts`.
    # Unique elements of the input array.
    values: np.ndarray
    # Number of occurrences of each unique element in the input.
    counts: np.ndarray


class UniqueInverseResult(NamedTuple):
    # Named 2-tuple returned by `unique_inverse`.
    # Unique elements of the input array.
    values: np.ndarray
    # Indices into the unique elements that reconstruct the input.
    inverse_indices: np.ndarray


def _unique_all_dispatcher(x, /):
    # The input array is the only argument handed to the dispatch machinery.
    return (x,)


@array_function_dispatch(_unique_all_dispatcher)
def unique_all(x):
    """
    Find the unique elements of `x` along with their first-occurrence
    indices, inverse indices and counts.

    This function is an Array API compatible alternative to::

        np.unique(x, return_index=True, return_inverse=True,
                  return_counts=True, equal_nan=False)

    Parameters
    ----------
    x : array_like
        Input array. It will be flattened if it is not already 1-D.

    Returns
    -------
    out : namedtuple
        The result containing:

        * values - The unique elements of an input array.
        * indices - The first occurring indices for each unique element.
        * inverse_indices - The indices from the set of unique elements
          that reconstruct `x`.
        * counts - The corresponding counts for each unique element.

    See Also
    --------
    unique : Find the unique elements of an array.

    """
    # Every optional output of `unique` is requested; NaNs are not merged.
    flags = {
        "return_index": True,
        "return_inverse": True,
        "return_counts": True,
        "equal_nan": False,
    }
    return UniqueAllResult._make(unique(x, **flags))


def _unique_counts_dispatcher(x, /):
    # The input array is the only argument handed to the dispatch machinery.
    return (x,)


@array_function_dispatch(_unique_counts_dispatcher)
def unique_counts(x):
    """
    Find the unique elements of `x` and the number of times each occurs.

    This function is an Array API compatible alternative to::

        np.unique(x, return_counts=True, equal_nan=False)

    Parameters
    ----------
    x : array_like
        Input array. It will be flattened if it is not already 1-D.

    Returns
    -------
    out : namedtuple
        The result containing:

        * values - The unique elements of an input array.
        * counts - The corresponding counts for each unique element.

    See Also
    --------
    unique : Find the unique elements of an array.

    """
    # Only the counts output is requested; NaNs are not merged.
    flags = {
        "return_index": False,
        "return_inverse": False,
        "return_counts": True,
        "equal_nan": False,
    }
    return UniqueCountsResult._make(unique(x, **flags))


def _unique_inverse_dispatcher(x, /):
    # The input array is the only argument handed to the dispatch machinery.
    return (x,)


@array_function_dispatch(_unique_inverse_dispatcher)
def unique_inverse(x):
    """
    Find the unique elements of `x` and the indices into them that
    reconstruct `x`.

    This function is an Array API compatible alternative to::

        np.unique(x, return_inverse=True, equal_nan=False)

    Parameters
    ----------
    x : array_like
        Input array. It will be flattened if it is not already 1-D.

    Returns
    -------
    out : namedtuple
        The result containing:

        * values - The unique elements of an input array.
        * inverse_indices - The indices from the set of unique elements
          that reconstruct `x`.

    See Also
    --------
    unique : Find the unique elements of an array.

    """
    # Only the inverse-indices output is requested; NaNs are not merged.
    flags = {
        "return_index": False,
        "return_inverse": True,
        "return_counts": False,
        "equal_nan": False,
    }
    return UniqueInverseResult._make(unique(x, **flags))


def _unique_values_dispatcher(x, /):
    # The input array is the only argument handed to the dispatch machinery.
    return (x,)


@array_function_dispatch(_unique_values_dispatcher)
def unique_values(x):
    """
    Find the unique elements of `x`.

    This function is an Array API compatible alternative to::

        np.unique(x, equal_nan=False)

    Parameters
    ----------
    x : array_like
        Input array. It will be flattened if it is not already 1-D.

    Returns
    -------
    out : ndarray
        The unique elements of an input array.

    See Also
    --------
    unique : Find the unique elements of an array.

    """
    # No optional outputs are requested; NaNs are not merged.
    flags = {
        "return_index": False,
        "return_inverse": False,
        "return_counts": False,
        "equal_nan": False,
    }
    return unique(x, **flags)


def _intersect1d_dispatcher(
ar1, ar2, assume_unique=None, return_indices=None):
return (ar1, ar2)
Expand Down
48 changes: 46 additions & 2 deletions numpy/lib/_arraysetops_impl.pyi
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import (
Literal as L,
Any,
TypeVar,
Generic,
Literal as L,
NamedTuple,
overload,
SupportsIndex,
TypeVar,
)

import numpy as np
Expand Down Expand Up @@ -85,6 +87,20 @@ _SCTNoCast = TypeVar(
void,
)

# Stub for the named tuple returned by `unique_all`; `_SCT` is the scalar
# type of the `values` array, while the remaining fields are `intp` arrays.
class UniqueAllResult(NamedTuple, Generic[_SCT]):
    values: NDArray[_SCT]
    indices: NDArray[intp]
    inverse_indices: NDArray[intp]
    counts: NDArray[intp]

# Stub for the named tuple returned by `unique_counts`; `_SCT` is the scalar
# type of the `values` array, and `counts` is an `intp` array.
class UniqueCountsResult(NamedTuple, Generic[_SCT]):
    values: NDArray[_SCT]
    counts: NDArray[intp]

# Stub for the named tuple returned by `unique_inverse`; `_SCT` is the scalar
# type of the `values` array, and `inverse_indices` is an `intp` array.
class UniqueInverseResult(NamedTuple, Generic[_SCT]):
    values: NDArray[_SCT]
    inverse_indices: NDArray[intp]

__all__: list[str]

@overload
Expand Down Expand Up @@ -279,6 +295,34 @@ def unique(
equal_nan: bool = ...,
) -> tuple[NDArray[Any], NDArray[intp], NDArray[intp], NDArray[intp]]: ...

# Input with a known scalar type: the result is parameterized by that type.
@overload
def unique_all(
    x: _ArrayLike[_SCT], /
) -> UniqueAllResult[_SCT]: ...
# Any other array-like input: the scalar type of `values` is `Any`.
@overload
def unique_all(
    x: ArrayLike, /
) -> UniqueAllResult[Any]: ...

# Input with a known scalar type: the result is parameterized by that type.
@overload
def unique_counts(
    x: _ArrayLike[_SCT], /
) -> UniqueCountsResult[_SCT]: ...
# Any other array-like input: the scalar type of `values` is `Any`.
@overload
def unique_counts(
    x: ArrayLike, /
) -> UniqueCountsResult[Any]: ...

# Input with a known scalar type: the result is parameterized by that type.
@overload
def unique_inverse(x: _ArrayLike[_SCT], /) -> UniqueInverseResult[_SCT]: ...
# Any other array-like input: the scalar type of `values` is `Any`.
@overload
def unique_inverse(x: ArrayLike, /) -> UniqueInverseResult[Any]: ...

# Input with a known scalar type: the returned array keeps that type.
@overload
def unique_values(x: _ArrayLike[_SCT], /) -> NDArray[_SCT]: ...
# Any other array-like input: the returned array's scalar type is `Any`.
@overload
def unique_values(x: ArrayLike, /) -> NDArray[Any]: ...

@overload
def intersect1d(
ar1: _ArrayLike[_SCTNoCast],
Expand Down
31 changes: 31 additions & 0 deletions numpy/lib/tests/test_arraysetops.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,3 +919,34 @@ def test_unique_nanequals(self):
not_unq = np.unique(a, equal_nan=False)
assert_array_equal(unq, np.array([1, np.nan]))
assert_array_equal(not_unq, np.array([1, np.nan, np.nan, np.nan]))

def test_unique_array_api_functions(self):
arr = np.array([np.nan, 1, 4, 1, 3, 4, np.nan, 5, 1])

for res_unique_array_api, res_unique in [
(
np.unique_values(arr),
np.unique(arr, equal_nan=False)
),
(
np.unique_counts(arr),
np.unique(arr, return_counts=True, equal_nan=False)
),
(
np.unique_inverse(arr),
np.unique(arr, return_inverse=True, equal_nan=False)
),
(
np.unique_all(arr),
np.unique(
arr,
return_index=True,
return_inverse=True,
return_counts=True,
equal_nan=False
)
)
]:
assert len(res_unique_array_api) == len(res_unique)
for actual, expected in zip(res_unique_array_api, res_unique):
assert_array_equal(actual, expected)
12 changes: 12 additions & 0 deletions numpy/typing/tests/data/reveal/arraysetops.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ from typing import Any

import numpy as np
import numpy.typing as npt
from numpy.lib._arraysetops_impl import (
UniqueAllResult, UniqueCountsResult, UniqueInverseResult
)

if sys.version_info >= (3, 11):
from typing import assert_type
Expand Down Expand Up @@ -61,3 +64,12 @@ assert_type(np.unique(AR_f8, return_inverse=True, return_counts=True), tuple[npt
assert_type(np.unique(AR_LIKE_f8, return_inverse=True, return_counts=True), tuple[npt.NDArray[Any], npt.NDArray[np.intp], npt.NDArray[np.intp]])
assert_type(np.unique(AR_f8, return_index=True, return_inverse=True, return_counts=True), tuple[npt.NDArray[np.float64], npt.NDArray[np.intp], npt.NDArray[np.intp], npt.NDArray[np.intp]])
assert_type(np.unique(AR_LIKE_f8, return_index=True, return_inverse=True, return_counts=True), tuple[npt.NDArray[Any], npt.NDArray[np.intp], npt.NDArray[np.intp], npt.NDArray[np.intp]])

# Array API set functions: an input with a known scalar type keeps that type
# in the result, while a generic array-like input yields `Any`.
assert_type(np.unique_all(AR_f8), UniqueAllResult[np.float64])
assert_type(np.unique_all(AR_LIKE_f8), UniqueAllResult[Any])
assert_type(np.unique_counts(AR_f8), UniqueCountsResult[np.float64])
assert_type(np.unique_counts(AR_LIKE_f8), UniqueCountsResult[Any])
assert_type(np.unique_inverse(AR_f8), UniqueInverseResult[np.float64])
assert_type(np.unique_inverse(AR_LIKE_f8), UniqueInverseResult[Any])
assert_type(np.unique_values(AR_f8), npt.NDArray[np.float64])
assert_type(np.unique_values(AR_LIKE_f8), npt.NDArray[Any])
Loading
0