8000 feat: add API specification for returning the `k` largest elements by kgryte · Pull Request #722 · data-apis/array-api · GitHub
[go: up one dir, main page]

Skip to content

feat: add API specification for returning the k largest elements #722

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Merge branch 'main' into feat/top_k
  • Loading branch information
kgryte authored Jan 25, 2024
commit e5d31892a6d4d3f88b42db629e1d4944d3ae6d55
1 change: 1 addition & 0 deletions spec/draft/API_specification/searching_functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Objects in API
argmax
argmin
nonzero
searchsorted
top_k
top_k_indices
top_k_values
Expand Down
55 changes: 53 additions & 2 deletions src/array_api_stubs/_draft/searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"argmax",
"argmin",
"nonzero",
"searchsorted",
"top_k",
"top_k_values",
"top_k_indices",
Expand Down Expand Up @@ -95,6 +96,56 @@ def nonzero(x: array, /) -> Tuple[array, ...]:
"""


def searchsorted(
x1: array,
x2: array,
/,
*,
side: Literal["left", "right"] = "left",
sorter: Optional[array] = None,
) -> array:
"""
Finds the indices into ``x1`` such that, if the corresponding elements in ``x2`` were inserted before the indices, the order of ``x1``, when sorted in ascending order, would be preserved.

Parameters
----------
x1: array
input array. Must be a one-dimensional array. Should have a real-valued data type. If ``sorter`` is ``None``, must be sorted in ascending order; otherwise, ``sorter`` must be an array of indices that sort ``x1`` in ascending order.
x2: array
array containing search values. Should have a real-valued data type.
side: Literal['left', 'right']
argument controlling which index is returned if a value lands exactly on an edge.

Let ``x`` be an array of rank ``N`` where ``v`` is an individual element given by ``v = x2[n,m,...,j]``.

If ``side == 'left'``, then

- each returned index ``i`` must satisfy the index condition ``x1[i-1] < v <= x1[i]``.
- if no index satisfies the index condition, then the returned index for that element must be ``0``.

Otherwise, if ``side == 'right'``, then

- each returned index ``i`` must satisfy the index condition ``x1[i-1] <= v < x1[i]``.
- if no index satisfies the index condition, then the returned index for that element must be ``N``, where ``N`` is the number of elements in ``x1``.

Default: ``'left'``.
sorter: Optional[array]
array of indices that sort ``x1`` in ascending order. The array must have the same shape as ``x1`` and have an integer data type. Default: ``None``.

Returns
-------
out: array
an array of indices with the same shape as ``x2``. The returned array must have the default array index data type.

Notes
-----

For real-valued floating-point arrays, the sort order of NaNs and signed zeros is unspecified and thus implementation-dependent. Accordingly, when a real-valued floating-point array contains NaNs and signed zeros, what constitutes ascending order may vary among specification-conforming array libraries.

While behavior for arrays containing NaNs and signed zeros is implementation-dependent, specification-conforming libraries should, however, ensure consistency with ``sort`` and ``argsort`` (i.e., if a value in ``x2`` is inserted into ``x1`` according to the corresponding index in the output array and ``sort`` is invoked on the resultant array, the sorted result should be an array in the same order).
"""


def top_k(
x: array,
k: int,
Expand Down Expand Up @@ -206,7 +257,7 @@ def top_k_values(
- ``'smallest'``: return the indices of the ``k`` smallest elements.

Default: ``'largest'``.

Returns
-------
out: array
Expand All @@ -220,7 +271,7 @@ def top_k_values(
- Conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`).
"""


def where(condition: array, x1: array, x2: array, /) -> array:
"""
Returns elements chosen from ``x1`` or ``x2`` depending on ``condition``.
Expand Down
You are viewing a condensed version of this merge commit. You can view the full changes here.
0