|
23 | 23 | 'ndim', 'nonzero', 'partition', 'prod', 'product', 'ptp', 'put',
|
24 | 24 | 'ravel', 'repeat', 'reshape', 'resize', 'round_',
|
25 | 25 | 'searchsorted', 'shape', 'size', 'sometrue', 'sort', 'squeeze',
|
26 |
| - 'std', 'sum', 'swapaxes', 'take', 'trace', 'transpose', 'var', |
| 26 | + 'std', 'sum', 'swapaxes', 'take', 'topk', 'trace', 'transpose', 'var', |
27 | 27 | ]
|
28 | 28 |
|
29 | 29 | _gentype = types.GeneratorType
|
@@ -1276,6 +1276,104 @@ def argmin(a, axis=None, out=None):
|
1276 | 1276 | return _wrapfunc(a, 'argmin', axis=axis, out=out)
|
1277 | 1277 |
|
1278 | 1278 |
|
| 1279 | +def _topk_dispatcher(a, k, axis=None, largest=None, sorted=None): |
| 1280 | + return (a,) |
| 1281 | + |
| 1282 | + |
| 1283 | +@array_function_dispatch(_topk_dispatcher) |
| 1284 | +def topk(a, k, axis=-1, largest=True, sorted=True): |
| 1285 | + """ |
| 1286 | + Finds values and indices of the `k` largest/smallest |
| 1287 | + elements along the given `axis`. |
| 1288 | +
|
| 1289 | + Parameters |
| 1290 | + ---------- |
| 1291 | + a: array_like |
| 1292 | + Array with given axis at least k. |
| 1293 | + k: int |
| 1294 | + Number of top elements to look for along the given axis. |
| 1295 | + axis: int or None, optional |
| 1296 | + Axis along which to find topk. If None, the array is flattened |
| 1297 | + before sorting. The default is -1 (the last axis). |
| 1298 | + largest: bool, optional |
| 1299 | + Controls whether to return largest or smallest elements. |
| 1300 | + sorted: bool, optional |
| 1301 | + If true the resulting k elements will be sorted by the values. |
| 1302 | +
|
| 1303 | + Returns |
| 1304 | + ------- |
| 1305 | + topk_values : ndarray |
| 1306 | + Array of values of `k` largest/smallest elements |
| 1307 | + along the specified `axis`. |
| 1308 | + topk_indices: ndarray, int |
| 1309 | + Array of indices of `k` largest/smallest elements |
| 1310 | + along the specified `axis`. |
| 1311 | +
|
| 1312 | + See Also |
| 1313 | + -------- |
| 1314 | + sort : Describes sorting algorithms used. |
| 1315 | + argsort : Indirect sort. |
| 1316 | + partition : Describes partition algorithms used. |
| 1317 | + argpartition : Indirect partial sort. |
| 1318 | + take_along_axis : Take values from the input array by |
| 1319 | + matching 1d index and data slices. |
| 1320 | +
|
| 1321 | + Examples |
| 1322 | + -------- |
| 1323 | + One dimensional array: |
| 1324 | +
|
| 1325 | + >>> x = np.array([3, 1, 2]) |
| 1326 | + >>> np.topk(x, 2) |
| 1327 | + (array([3, 2]), array([0, 2], dtype=int64)) |
| 1328 | +
|
| 1329 | + Two-dimensional array: |
| 1330 | +
|
| 1331 | + >>> x = np.array([[0, 3, 4], [2, 2, 1], [5, 1, 2]]) |
| 1332 | + >>> val, ind = np.topk(x, 2, axis=1) # along the last axis |
| 1333 | + >>> val |
| 1334 | + array([[4, 3], |
| 1335 | + [2, 2], |
| 1336 | + [5, 2]]) |
| 1337 | + >>> ind |
| 1338 | + array([[2, 1], |
| 1339 | + [0, 1], |
| 1340 | + [0, 2]]) |
| 1341 | +
|
| 1342 | + >>> val, ind = np.topk(x, 2, axis=None) # along the flattened array |
| 1343 | + >>> val |
| 1344 | + array([5, 4]) |
| 1345 | + >>> ind |
| 1346 | + array([6, 2]) |
| 1347 | +
|
| 1348 | + >>> val, ind = np.topk(x, 2, axis=0) # along the first axis |
| 1349 | + >>> val |
| 1350 | + array([[5, 3, 4], |
| 1351 | + [2, 2, 2]]) |
| 1352 | + >>> ind |
| 1353 | + array([[2, 0, 0], |
| 1354 | + [1, 1, 2]]) |
| 1355 | + """ |
| 1356 | + if largest: |
| 1357 | + index_array = np.argpartition(-a, k-1, axis=axis, order=None) |
| 1358 | + else: |
| 1359 | + index_array = np.argpartition(a, k-1, axis=axis, order=None) |
| 1360 | + topk_indices = np.take(index_array, range(k), axis=axis) |
| 1361 | + topk_values = np.take_along_axis(a, topk_indices, axis=axis) |
| 1362 | + if sorted: |
| 1363 | + if largest: |
| 1364 | + sorted_indices_in_topk = np.argsort( |
| 1365 | + -topk_values, axis=axis, order=None) |
| 1366 | + else: |
| 1367 | + sorted_indices_in_topk = np.argsort( |
| 1368 | + topk_values, axis=axis, order=None) |
| 1369 | + sorted_topk_values = np.take_along_axis( |
| 1370 | + topk_values, sorted_indices_in_topk, axis=axis) |
| 1371 | + sorted_topk_indices = np.take_along_axis( |
| 1372 | + topk_indices, sorted_indices_in_topk, axis=axis) |
| 1373 | + return sorted_topk_values, sorted_topk_indices |
| 1374 | + return topk_values, topk_indices |
| 1375 | + |
| 1376 | + |
1279 | 1377 | def _searchsorted_dispatcher(a, v, side=None, sorter=None):
|
1280 | 1378 | return (a, v, sorter)
|
1281 | 1379 |
|
|
0 commit comments