8000 API: Add Array API setops · numpy/numpy@4327973 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4327973

Browse files
committed
API: Add Array API setops
1 parent 9340fca commit 4327973

File tree

9 files changed

+295
-13
lines changed
  • typing/tests/data/reveal
  • 9 files changed

    +295
    -13
    lines changed
    Lines changed: 6 additions & 0 deletions
    Original file line numberDiff line numberDiff line change
    @@ -0,0 +1,6 @@
    1+
    Array API set functions
    2+
    -----------------------
    3+
    4+
    `numpy.unique_all`, `numpy.unique_counts`, `numpy.unique_inverse`,
    5+
    and `numpy.unique_values` functions have been added for Array API compatiblity.
    6+
    They provide functionality of `numpy.unique` with different sets of flags.

    doc/source/reference/array_api.rst

    Lines changed: 0 additions & 4 deletions
    Original file line numberDiff line numberDiff line change
    @@ -119,10 +119,6 @@ The following functions are named differently in the array API
    119119
    * - ``pow``
    120120
    - ``power``
    121121
    -
    122-
    * - ``unique_all``, ``unique_counts``, ``unique_inverse``, and
    123-
    ``unique_values``
    124-
    - ``unique``
    125-
    - Each is equivalent to ``np.unique`` with certain flags set.
    126122

    127123

    128124
    Function instead of method

    doc/source/reference/routines.set.rst

    Lines changed: 4 additions & 0 deletions
    Original file line numberDiff line numberDiff line change
    @@ -9,6 +9,10 @@ Making proper sets
    99
    :toctree: generated/
    1010

    1111
    unique
    12+
    unique_all
    13+
    unique_counts
    14+
    unique_inverse
    15+
    unique_values
    1216

    1317
    Boolean operations
    1418
    ------------------

    numpy/__init__.py

    Lines changed: 2 additions & 1 deletion
    Original file line numberDiff line numberDiff line change
    @@ -207,7 +207,8 @@
    207207
    real_if_close, typename, mintypecode, common_type
    208208
    )
    209209
    from .lib._arraysetops_impl import (
    210-
    ediff1d, in1d, intersect1d, isin, setdiff1d, setxor1d, union1d, unique
    210+
    ediff1d, in1d, intersect1d, isin, setdiff1d, setxor1d, union1d,
    211+
    unique, unique_all, unique_counts, unique_inverse, unique_values
    211212
    )
    212213
    from .lib._ufunclike_impl import fix, isneginf, isposinf
    213214
    from .lib._arraypad_impl import pad

    numpy/__init__.pyi

    Lines changed: 7 additions & 3 deletions
    Original file line numberDiff line numberDiff line change
    @@ -405,13 +405,17 @@ from numpy.lib._arraypad_impl import (
    405405

    406406
    from numpy.lib._arraysetops_impl import (
    407407
    ediff1d as ediff1d,
    408+
    in1d as in1d,
    408409
    intersect1d as intersect1d,
    410+
    isin as isin,
    411+
    setdiff1d as setdiff1d,
    409412
    setxor1d as setxor1d,
    410413
    union1d as union1d,
    411-
    setdiff1d as setdiff1d,
    412414
    unique as unique,
    413-
    in1d as in1d,
    414-
    isin as isin,
    415+
    unique_all as unique_all,
    416+
    unique_counts as unique_counts,
    417+
    unique_inverse as unique_inverse,
    418+
    unique_values as unique_values,
    415419
    )
    416420

    417421
    from numpy.lib._function_base_impl import (

    numpy/lib/_arraysetops_impl.py

    Lines changed: 187 additions & 3 deletions
    Original file line numberDiff line numberDiff line change
    @@ -16,6 +16,7 @@
    1616
    """
    1717
    import functools
    1818
    import warnings
    19+
    from typing import NamedTuple
    1920

    2021
    import numpy as np
    2122
    from numpy._core import overrides
    @@ -26,9 +27,10 @@
    2627

    2728

    2829
    __all__ = [
    29-
    'ediff1d', 'intersect1d', 'setxor1d', 'union1d', 'setdiff1d', 'unique',
    30-
    'in1d', 'isin'
    31-
    ]
    30+
    "ediff1d", "in1d", "intersect1d", "isin", "setdiff1d", "setxor1d",
    31+
    "union1d", "unique", "unique_all", "unique_counts", "unique_inverse",
    32+
    "unique_values"
    33+
    ]
    3234

    3335

    3436
    def _ediff1d_dispatcher(ary, to_end=None, to_begin=None):
    @@ -364,6 +366,188 @@ def _unique1d(ar, return_index=False, return_inverse=False,
    364366
    return ret
    365367

    366368

    369+
    # Array API set functions
    370+
    371+
    class UniqueAllResult(NamedTuple):
    372+
    values: np.ndarray
    373+
    indices: np.ndarray
    374+
    inverse_indices: np.ndarray
    375+
    counts: np.ndarray
    376+
    377+
    378+
    class UniqueCountsResult(NamedTuple):
    379+
    values: np.ndarray
    380+
    counts: np.ndarray
    381+
    382+
    383+
    class UniqueInverseResult(NamedTuple):
    384+
    values: np.ndarray
    385+
    inverse_indices: np.ndarray
    386+
    387+
    388+
    def _unique_all_dispatcher(x, /):
    389+
    return (x,)
    390+
    391+
    392+
    @array_function_dispatch(_unique_all_dispatcher)
    393+
    def unique_all(x):
    394+
    """
    395+
    Returns the unique elements of an input array `x`, the first occurring
    396+
    indices for each unique element in `x`, the indices from the set of unique
    397+
    elements that reconstruct `x`, and the corresponding counts for each
    398+
    unique element in `x`.
    399+
    400+
    This function is Array API compatible alternative to:
    401+
    ``np.unique(x, return_index=True, return_inverse=True, return_counts=True,
    402+
    equal_nan=False)``
    403+
    404+
    Parameters
    405+
    ----------
    406+
    x : array_like
    407+
    Input array. It will be flattened if it is not already 1-D.
    408+
    409+
    Returns
    410+
    -------
    411+
    out : namedtuple
    412+
    The result containing:
    413+
    * values - The unique elements of an input array.
    414+
    * indices - The first occurring indices for each unique element.
    415+
    * inverse_indices - The indices from the set of unique elements
    416+
    that reconstruct `x`.
    417+
    * counts - The corresponding counts for each unique element.
    418+
    419+
    See Also
    420+
    --------
    421+
    unique : Find the unique elements of an array.
    422+
    423+
    """
    424+
    result = unique(
    425+
    x,
    426+
    return_index=True,
    427+
    return_inverse=True,
    428+
    return_counts=True,
    429+
    equal_nan=False
    430+
    )
    431+
    return UniqueAllResult(*result)
    432+
    433+
    434+
    def _unique_counts_dispatcher(x, /):
    435+
    return (x,)
    436+
    437+
    438+
    @array_function_dispatch(_unique_counts_dispatcher)
    439+
    def unique_counts(x):
    440+
    """
    441+
    Returns the unique elements of an input array `x`, and the corresponding
    442+
    counts for each unique element in `x`.
    443+
    444+
    This function is Array API compatible alternative to:
    445+
    ``np.unique(x, return_counts=True, equal_nan=False)``
    446+
    447+
    Parameters
    448+
    ----------
    449+
    x : array_like
    450+
    Input array. It will be flattened if it is not already 1-D.
    451+
    452+
    Returns
    453+
    -------
    454+
    out : namedtuple
    455+
    The result containing:
    456+
    * values - The unique elements of an input array.
    457+
    * counts - The corresponding counts for each unique element.
    458+
    459+
    See Also
    460+
    --------
    461+
    unique : Find the unique elements of an array.
    462+
    463+
    """
    464+
    result = unique(
    465+
    x,
    466+
    return_index=False,
    467+
    return_inverse=False,
    468+
    return_counts=True,
    469+
    equal_nan=False
    470+
    )
    471+
    return UniqueCountsResult(*result)
    472+
    473+
    474+
    def _unique_inverse_dispatcher(x, /):
    475+
    return (x,)
    476+
    477+
    478+
    @array_function_dispatch(_unique_inverse_dispatcher)
    479+
    def unique_inverse(x):
    480+
    """
    481+
    Returns the unique elements of an input array `x` and the indices
    482+
    from the set of unique elements that reconstruct `x`.
    483+
    484+
    This function is Array API compatible alternative to:
    485+
    ``np.unique(x, return_inverse=True, equal_nan=False)``
    486+
    487+
    Parameters
    488+
    ----------
    489+
    x : array_like
    490+
    Input array. It will be flattened if it is not already 1-D.
    491+
    492+
    Returns
    493+
    -------
    494+
    out : namedtuple
    495+
    The result containing:
    496+
    * values - The unique elements of an input array.
    497+
    * inverse_indices - The indices from the set of unique elements
    498+
    that reconstruct `x`.
    499+
    500+
    See Also
    501+
    --------
    502+
    unique : Find the unique elements of an array.
    503+
    504+
    """
    505+
    result = unique(
    506+
    x,
    507+
    return_index=False,
    508+
    return_inverse=True,
    509+
    return_counts=False,
    510+
    equal_nan=False
    511+
    )
    512+
    return UniqueInverseResult(*result)
    513+
    514+
    515+
    def _unique_values_dispatcher(x, /):
    516+
    return (x,)
    517+
    518+
    519+
    @array_function_dispatch(_unique_values_dispatcher)
    520+
    def unique_values(x):
    521+
    """
    522+
    Returns the unique elements of an input array `x`.
    523+
    524+
    This function is Array API compatible alternative to:
    525+
    ``np.unique(x, equal_nan=False)``
    526+
    527+
    Parameters
    528+
    ----------
    529+
    x : array_like
    530+
    Input array. It will be flattened if it is not already 1-D.
    531+
    532+
    Returns
    533+
    -------
    534+
    out : ndarray
    535+
    The unique elements of an input array.
    536+
    537+
    See Also
    538+
    --------
    539+
    unique : Find the unique elements of an array.
    540+
    541+
    """
    542+
    return unique(
    543+
    x,
    544+
    return_index=False,
    545+
    return_inverse=False,
    546+
    return_counts=False,
    547+
    equal_nan=False
    548+
    )
    549+
    550+
    367551
    def _intersect1d_dispatcher(
    368552
    ar1, ar2, assume_unique=None, return_indices=None):
    369553
    return (ar1, ar2)

    numpy/lib/_arraysetops_impl.pyi

    Lines changed: 46 additions & 2 deletions
    Original file line numberDiff line numberDiff line change
    @@ -1,9 +1,11 @@
    11
    from typing import (
    2-
    Literal as L,
    32
    Any,
    4-
    TypeVar,
    3+
    Generic,
    4+
    Literal as L,
    5+
    NamedTuple,
    56
    overload,
    67
    SupportsIndex,
    8+
    TypeVar,
    79
    )
    810

    911
    from numpy import (
    @@ -85,6 +87,20 @@ _SCTNoCast = TypeVar(
    8587
    void,
    8688
    )
    8789

    90+
    class UniqueAllResult(NamedTuple, Generic[_SCT]):
    91+
    values: NDArray[_SCT]
    92+
    indices: NDArray[intp]
    93+
    inverse_indices: NDArray[intp]
    94+
    counts: NDArray[intp]
    95+
    96+
    class UniqueCountsResult(NamedTuple, Generic[_SCT]):
    97+
    values: NDArray[_SCT]
    98+
    counts: NDArray[intp]
    99+
    100+
    class UniqueInverseResult(NamedTuple, Generic[_SCT]):
    101+
    values: NDArray[_SCT]
    102+
    inverse_indices: NDArray[intp]
    103+
    88104
    __all__: list[str]
    89105

    90106
    @overload
    @@ -279,6 +295,34 @@ def unique(
    279295
    equal_nan: bool = ...,
    280296
    ) -> tuple[NDArray[Any], NDArray[intp], NDArray[intp], NDArray[intp]]: ...
    281297

    298+
    @overload
    299+
    def unique_all(
    300+
    x: _ArrayLike[_SCT], /
    301+
    ) -> UniqueAllResult[_SCT]: ...
    302+
    @overload
    303+
    def unique_all(
    304+
    x: ArrayLike, /
    305+
    ) -> UniqueAllResult[Any]: ...
    306+
    307+
    @overload
    308+
    def unique_counts(
    309+
    x: _ArrayLike[_SCT], /
    310+
    ) -> UniqueCountsResult[_SCT]: ...
    311+
    @overload
    312+
    def unique_counts(
    313+
    x: ArrayLike, /
    314+
    ) -> UniqueCountsResult[Any]: ...
    315+
    316+
    @overload
    317+
    def unique_inverse(x: _ArrayLike[_SCT], /) -> UniqueInverseResult[_SCT]: ...
    318+
    @overload
    319+
    def unique_inverse(x: ArrayLike, /) -> UniqueInverseResult[Any]: ...
    320+
    321+
    @overload
    322+
    def unique_values(x: _ArrayLike[_SCT], /) -> NDArray[_SCT]: ...
    323+
    @overload
    324+
    def unique_values(x: ArrayLike, /) -> NDArray[Any]: ...
    325+
    282326
    @overload
    283327
    def intersect1d(
    284328
    ar1: _ArrayLike[_SCTNoCast],

    numpy/lib/tests/test_arraysetops.py

    Lines changed: 31 additions & 0 deletions
    Original file line numberDiff line numberDiff line change
    @@ -919,3 +919,34 @@ def test_unique_nanequals(self):
    919919
    not_unq = np.unique(a, equal_nan=False)
    920920
    assert_array_equal(unq, np.array([1, np.nan]))
    921921
    assert_array_equal(not_unq, np.array([1, np.nan, np.nan, np.nan]))
    922+
    923+
    def test_unique_array_api_functions(self):
    924+
    arr = np.array([np.nan, 1, 4, 1, 3, 4, np.nan, 5, 1])
    925+
    926+
    for res_unique_array_api, res_unique in [
    927+
    (
    928+
    np.unique_values(arr),
    929+
    np.unique(arr, equal_nan=False)
    930+
    ),
    931+
    (
    932+
    np.unique_counts(arr),
    933+
    np.unique(arr, return_counts=True, equal_nan=False)
    934+
    ),
    935+
    (
    936+
    np.unique_inverse(arr),
    937+
    np.unique(arr, return_inverse=True, equal_nan=False)
    938+
    ),
    939+
    (
    940+
    np.unique_all(arr),
    941+
    np.unique(
    942+
    arr,
    943+
    return_index=True,
    944+
    return_inverse=True,
    945+
    return_counts=True,
    946+
    equal_nan=False
    947+
    )
    948+
    )
    949+
    ]:
    950+
    assert len(res_unique_array_api) == len(res_unique)
    951+
    for actual, expected in zip(res_unique_array_api, res_unique):
    952+
    assert_array_equal(actual, expected)

    numpy/typing/tests/data/reveal/arraysetops.pyi

    Lines changed: 12 additions & 0 deletions
    Original file line numberDiff line numberDiff line change
    @@ -3,6 +3,9 @@ from typing import Any
    33

    44
    import nump 741A y as np
    55
    import numpy.typing as npt
    6+
    from numpy.lib._arraysetops_impl import (
    7+
    UniqueAllResult, UniqueCountsResult, UniqueInverseResult
    8+
    )
    69

    710
    if sys.version_info >= (3, 11):
    811
    from typing import assert_type
    @@ -61,3 +64,12 @@ assert_type(np.unique(AR_f8, return_inverse=True, return_counts=True), tuple[npt
    6164
    assert_type(np.unique(AR_LIKE_f8, return_inverse=True, return_counts=True), tuple[npt.NDArray[Any], npt.NDArray[np.intp], npt.NDArray[np.intp]])
    6265
    assert_type(np.unique(AR_f8, return_index=True, return_inverse=True, return_counts=True), tuple[npt.NDArray[np.float64], npt.NDArray[np.intp], npt.NDArray[np.intp], npt.NDArray[np.intp]])
    6366
    assert_type(np.unique(AR_LIKE_f8, return_index=True, return_inverse=True, return_counts=True), tuple[npt.NDArray[Any], npt.NDArray[np.intp], npt.NDArray[np.intp], npt.NDArray[np.intp]])
    67+
    68+
    assert_type(np.unique_all(AR_f8), UniqueAllResult[np.float64])
    69+
    assert_type(np.unique_all(AR_LIKE_f8), UniqueAllResult[Any])
    70+
    assert_type(np.unique_counts(AR_f8), UniqueCountsResult[np.float64])
    71+
    assert_type(np.unique_counts(AR_LIKE_f8), UniqueCountsResult[Any])
    72+
    assert_type(np.unique_inverse(AR_f8), UniqueInverseResult[np.float64])
    73+
    assert_type(np.unique_inverse(AR_LIKE_f8), 3FBE UniqueInverseResult[Any])
    74+
    assert_type(np.unique_values(AR_f8), npt.NDArray[np.float64])
    75+
    assert_type(np.unique_values(AR_LIKE_f8), npt.NDArray[Any])

    0 commit comments

    Comments
     (0)
    0