8000 EHN: add numpy.topk · numpy/numpy@d9ef740 · GitHub
[go: up one dir, main page]

Skip to content

Commit d9ef740

Browse files
committed
EHN: add numpy.topk
1 parent 89da723 commit d9ef740

File tree

2 files changed

+107
-1
lines changed

2 files changed

+107
-1
lines changed

numpy/core/fromnumeric.py

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
'ndim', 'nonzero', 'partition', 'prod', 'product', 'ptp', 'put',
2424
'ravel', 'repeat', 'reshape', 'resize', 'round_',
2525
'searchsorted', 'shape', 'size', 'sometrue', 'sort', 'squeeze',
26-
'std', 'sum', 'swapaxes', 'take', 'trace', 'transpose', 'var',
26+
'std', 'sum', 'swapaxes', 'take', 'topk', 'trace', 'transpose', 'var',
2727
]
2828

2929
_gentype = types.GeneratorType
@@ -1276,6 +1276,104 @@ def argmin(a, axis=None, out=None):
12761276
return _wrapfunc(a, 'argmin', axis=axis, out=out)
12771277

12781278

1279+
def _topk_dispatcher(a, k, axis=None, largest=None, sorted=None):
1280+
return (a,)
1281+
1282+
1283+
@array_function_dispatch(_topk_dispatcher)
1284+
def topk(a, k, axis=-1, largest=True, sorted=True):
1285+
"""
1286+
Finds values and indices of the `k` largest/smallest
1287+
elements along the given `axis`.
1288+
1289+
Parameters
1290+
----------
1291+
a: array_like
1292+
Array with given axis at least k.
1293+
k: int
1294+
Number of top elements to look for along the given axis.
1295+
axis: int or None, optional
1296+
Axis along which to find topk. If None, the array is flattened
1297+
before sorting. The default is -1 (the last axis).
1298+
largest: bool, optional
1299+
Controls whether to return largest or smallest elements.
1300+
sorted: bool, optional
1301+
If true the resulting k elements will be sorted by the values.
1302+
1303+
Returns
1304+
-------
1305+
topk_values : ndarray
1306+
Array of values of `k` largest/smallest elements
1307+
along the specified `axis`.
1308+
topk_indices: ndarray, int
1309+
Array of indices of `k` largest/smallest elements
1310+
along the specified `axis`.
1311+
1312+
See Also
1313+
--------
1314+
sort : Describes sorting algorithms used.
1315+
argsort : Indirect sort.
1316+
partition : Describes partition algorithms used.
1317+
argpartition : Indirect partial sort.
1318+
take_along_axis : Take values from the input array by
1319+
matching 1d index and data slices.
1320+
1321+
Examples
1322+
--------
1323+
One dimensional array:
1324+
1325+
>>> x = np.array([3, 1, 2])
1326+
>>> np.topk(x, 2)
1327+
(array([3, 2]), array([0, 2], dtype=int64))
1328+
1329+
Two-dimensional array:
1330+
1331+
>>> x = np.array([[0, 3, 4], [2, 2, 1], [5, 1, 2]])
1332+
>>> val, ind = np.topk(x, 2, axis=1) # along the last axis
1333+
>>> val
1334+
array([[4, 3],
1335+
[2, 2],
1336+
[5, 2]])
1337+
>>> ind
1338+
array([[2, 1],
1339+
[0, 1],
1340+
[0, 2]])
1341+
1342+
>>> val, ind = np.topk(x, 2, axis=None) # along the flattened array
1343+
>>> val
1344+
array([5, 4])
1345+
>>> ind
1346+
array([6, 2])
1347+
1348+
>>> val, ind = np.topk(x, 2, axis=0) # along the first axis
1349+
>>> val
1350+
array([[5, 3, 4],
1351+
[2, 2, 2]])
1352+
>>> ind
1353+
array([[2, 0, 0],
1354+
[1, 1, 2]])
1355+
"""
1356+
if largest:
1357+
index_array = np.argpartition(-a, k-1, axis=axis, order=None)
1358+
else:
1359+
index_array = np.argpartition(a, k-1, axis=axis, order=None)
1360+
topk_indices = np.take(index_array, range(k), axis=axis)
1361+
topk_values = np.take_along_axis(a, topk_indices, axis=axis)
1362+
if sorted:
1363+
if largest:
1364+
sorted_indices_in_topk = np.argsort(
1365+
-topk_values, axis=axis, order=None)
1366+
else:
1367+
sorted_indices_in_topk = np.argsort(
1368+
topk_values, axis=axis, order=None)
1369+
sorted_topk_values = np.take_along_axis(
1370+
topk_values, sorted_indices_in_topk, axis=axis)
1371+
sorted_topk_indices = np.take_along_axis(
1372+
topk_indices, sorted_indices_in_topk, axis=axis)
1373+
return sorted_topk_values, sorted_topk_indices
1374+
return topk_values, topk_indices
1375+
1376+
12791377
def _searchsorted_dispatcher(a, v, side=None, sorter=None):
12801378
return (a, v, sorter)
12811379

numpy/core/fromnumeric.pyi

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,14 @@ def argmin(
151151
out: Optional[ndarray] = ...,
152152
) -> Any: ...
153153

154+
def topk(
155+
a: ArrayLike,
156+
k: Optional[int] = ...,
157+
axis: Optional[int] = ...,
158+
largest: Optional[bool] = ...,
159+
sorted: Optional[bool] = ...,
160+
) ->< 4C79 /span> Tuple[ndarray, ndarray]: ...
161+
154162
@overload
155163
def searchsorted(
156164
a: ArrayLike,

0 commit comments

Comments
 (0)
0