8000 WIP: top_k draft implementation · JuliaPoo/numpy@5984fe6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5984fe6

Browse files
committed
WIP: top_k draft implementation
Following the previous discussion at numpy#15128. I made a small change to the interface proposed in that discussion, replacing the `mode` keyword with a `largest` bool flag. This follows APIs such as [torch.topk](https://pytorch.org/docs/stable/generated/torch.topk.html). Carrying over from the previous discussion, a parameter that might be useful is `sorted`. It is also implemented in `torch.topk`, and follows from previous work at numpy#19117. Co-authored-by: quarrying
1 parent 92412e9 commit 5984fe6

File tree

5 files changed

+160
-2
lines changed

5 files changed

+160
-2
lines changed

numpy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@
163163
shares_memory, short, sign, signbit, signedinteger, sin, single, sinh,
164164
size, sort, spacing, sqrt, square, squeeze, stack, std,
165165
str_, subtract, sum, swapaxes, take, tan, tanh, tensordot,
166-
timedelta64, trace, transpose, true_divide, trunc, typecodes, ubyte,
166+
timedelta64, top_k, trace, transpose, true_divide, trunc, typecodes, ubyte,
167167
ufunc, uint, uint16, uint32, uint64, uint8, uintc, uintp, ulong,
168168
ulonglong, unsignedinteger, ushort, var, vdot, vecdot, void, vstack,
169169
where, zeros, zeros_like

numpy/_core/fromnumeric.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
'ndim', 'nonzero', 'partition', 'prod', 'ptp', 'put',
2727
'ravel', 'repeat', 'reshape', 'resize', 'round',
2828
'searchsorted', 'shape', 'size', 'sort', 'squeeze',
29-
'std', 'sum', 'swapaxes', 'take', 'trace', 'transpose', 'var',
29+
'std', 'sum', 'swapaxes', 'take', 'top_k', 'trace',
30+
'transpose', 'var',
3031
]
3132

3233
_gentype = types.GeneratorType
@@ -206,6 +207,100 @@ def take(a, indices, axis=None, out=None, mode='raise'):
206207
return _wrapfunc(a, 'take', indices, axis=axis, out=out, mode=mode)
207208

208209

210+
def _top_k_dispatcher(a, k, /, *, axis=-1, largest=True):
211+
return (a,)
212+
213+
214+
@array_function_dispatch(_top_k_dispatcher)
def top_k(a, k, /, *, axis=-1, largest=True):
    """
    Return the ``k`` largest or smallest elements along an axis, together
    with their indices.

    The selection is done with `argpartition`, so the returned elements
    are not guaranteed to be in sorted order.

    Parameters
    ----------
    a : array_like
        The source array.
    k : int
        The number of largest/smallest elements to return. ``k`` must
        be a positive integer and within the indexable range of ``axis``.
    axis : int or None, optional
        Axis along which to find the largest/smallest elements.
        The default is -1 (the last axis).
        If None, the flattened array is used.
    largest : bool, optional
        If True (default), the ``k`` largest elements are returned.
        Otherwise the ``k`` smallest elements are returned.

    Returns
    -------
    topk_values : ndarray
        The selected elements of the source array (not necessarily in
        sorted order).
    topk_indices : ndarray
        Indices of ``topk_values`` along ``axis`` (or into the flattened
        array when ``axis`` is None).

    Raises
    ------
    ValueError
        If ``k`` is not positive.

    See Also
    --------
    argpartition : Indirect partition.
    sort : Full sorting.

    Examples
    --------
    >>> a = np.array([[1,2,3,4,5], [5,4,3,2,1], [3,4,5,1,2]])
    >>> np.top_k(a, 2)
    (array([[4, 5],
            [4, 5],
            [4, 5]]),
     array([[3, 4],
            [1, 0],
            [1, 2]]))
    >>> np.top_k(a, 2, axis=0)
    (array([[3, 4, 3, 2, 2],
            [5, 4, 5, 4, 5]]),
     array([[2, 1, 1, 1, 2],
            [1, 2, 2, 0, 0]]))
    >>> a.flatten()
    array([1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 3, 4, 5, 1, 2])
    >>> np.top_k(a, 2, axis=None)
    (array([5, 5]), array([ 5, 12]))
    """
    if k <= 0:
        raise ValueError(f'k(={k}) provided must be positive.')

    arr = np.asanyarray(a)

    # After partitioning around ``kth``, the k largest entries occupy the
    # last k positions along the axis and the k smallest the first k.
    if largest:
        kth = -k
        window = slice(-k, None)
    else:
        kth = k - 1
        window = slice(None, k)

    # Build an index tuple that applies ``window`` on ``axis`` and takes
    # everything on the other axes. Negative axes are addressed from the
    # right through Ellipsis, so ``arr.ndim`` is never needed (the former
    # ``axis % arr.ndim`` divided by zero for 0-d input).
    if axis is None:
        arr = arr.ravel()
        selector = (window,)
    elif axis >= 0:
        selector = (slice(None),) * axis + (window,)
    else:
        selector = (Ellipsis, window) + (slice(None),) * (-axis - 1)

    partition_order = np.argpartition(arr, kth, axis=axis)
    topk_indices = partition_order[selector]
    topk_values = np.take_along_axis(arr, topk_indices, axis=axis)

    return (topk_values, topk_indices)
302+
303+
209304
def _reshape_dispatcher(a, /, shape=None, *, newshape=None, order=None,
210305
copy=None):
211306
return (a,)

numpy/_core/fromnumeric.pyi

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,15 @@ def take(
8989
mode: _ModeKind = ...,
9090
) -> _ArrayType: ...
9191

92+
def top_k(
93+
a: ArrayLike,
94+
k: int,
95+
/,
96+
*,
97+
axis: None | int = ...,
98+
largest: bool = ...,
99+
) -> tuple[NDArray[Any], NDArray[intp]]: ...
100+
92101
@overload
93102
def reshape(
94103
a: _ArrayLike[_SCT],

numpy/_core/tests/test_multiarray.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3176,6 +3176,54 @@ def test_argpartition_gh5524(self, kth_dtype):
31763176
p = np.argpartition(d, kth)
31773177
self.assert_partitioned(np.array(d)[p],[1])
31783178

3179+
def assert_top_k(self, a, axis: int, x, y):
    """
    Check a ``(values, indices)`` result ``x`` against the expected ``y``.

    Values and indices are compared after sorting along ``axis``, so the
    order of the selected elements does not matter. ``x`` must also be
    self-consistent: its indices have to select its values from ``a``.
    """
    x_value, x_indices = x
    y_value, y_indices = y
    assert_equal(np.sort(x_value, axis=axis), np.sort(y_value, axis=axis))
    assert_equal(np.sort(x_indices, axis=axis), np.sort(y_indices, axis=axis))
    assert_equal(np.take_along_axis(a, x_indices, axis=axis), x_value)
3185+
3186+
def test_top_k(self):
    # Exercise np.top_k along the last axis, along axis 0, for the
    # smallest elements, and on the flattened input.
    a = np.array([
        [1, 2, 3, 4, 5],
        [5, 4, 2, 3, 1],
        [3, 5, 4, 1, 2]
    ], dtype=np.int8)

    # Non-positive k is rejected before any partitioning happens.
    with assert_raises_regex(
        ValueError,
        r"k\(=0\) provided must be positive."
    ):
        np.top_k(a, 0)
    with assert_raises_regex(
        ValueError,
        r"k\(=-1\) provided must be positive."
    ):
        np.top_k(a, -1)

    # Default axis (-1) and the equivalent explicit check axis 1.
    y = (
        np.array([[4, 5], [4, 5], [4, 5]], dtype=np.int8),
        np.array([[3, 4], [0, 1], [1, 2]], dtype=np.intp)
    )
    self.assert_top_k(a, -1, np.top_k(a, 2), y)
    self.assert_top_k(a, 1, np.top_k(a, 2), y)

    axis = 0
    y = (
        np.array([[5, 4, 3, 4, 5],
                  [3, 5, 4, 3, 2]], dtype=np.int8),
        # Expected indices use intp like the other cases in this test.
        np.array([[1, 1, 0, 0, 0],
                  [2, 2, 2, 1, 2]], dtype=np.intp)
    )
    self.assert_top_k(a, axis, np.top_k(a, 2, axis=axis), y)

    # Smallest elements along the last axis.
    y = (
        np.array([[1, 2], [1, 2], [1, 2]], dtype=np.int8),
        np.array([[0, 1], [2, 4], [3, 4]], dtype=np.intp)
    )
    self.assert_top_k(a, -1, np.top_k(a, 2, largest=False), y)
    self.assert_top_k(a, 1, np.top_k(a, 2, largest=False), y)

    # axis=None works on the flattened array; the value 5 occurs three
    # times, so only the values and index/value consistency are checked.
    y_val, y_ind = np.top_k(a, 2, axis=None)
    assert_equal(y_val, np.array([5, 5], dtype=np.int8))
    assert_equal(np.take_along_axis(a.ravel(), y_ind, axis=-1), y_val)
3226+
31793227
def test_flatten(self):
31803228
x0 = np.array([[1, 2, 3], [4, 5, 6]], np.int32)
31813229
x1 = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], np.int32)

numpy/_core/tests/test_numeric.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,12 @@ def test_take(self):
336336
assert_equal(out, tgt)
337337
assert_equal(out.dtype, tgt.dtype)
338338

339+
def test_top_k(self):
    # top_k accepts a plain nested list; with k=1 on the default (last)
    # axis it yields the maximum of each row and its position.
    a = [[1, 2], [2, 1]]
    y = ([[2], [2]], [[1], [0]])
    out = np.top_k(a, 1)
    assert_equal(out, y)
344+
339345
def test_trace(self):
340346
c = [[1, 2], [3, 4], [5, 6]]
341347
assert_equal(np.trace(c), 5)

0 commit comments

Comments
 (0)
0