8000 Add top_k comparisons · data-apis/array-api-comparison@8d0aaf0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8d0aaf0

Browse files
committed
Add top_k comparisons
1 parent 20880a5 commit 8d0aaf0

File tree

1 file changed

+88
-0
lines changed

1 file changed

+88
-0
lines changed

signatures/searching/top_k.md

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# top_k
2+
3+
## NumPy
4+
5+
```
6+
numpy.topk(a, k, axis=-1, largest=True, sorted=True) → [ndarray, ndarray]
7+
```
8+
9+
**Note**: this is not present in NumPy, but is proposed in <https://github.com/numpy/numpy/pull/19117>.
10+
11+
```
12+
numpy.partition(a, kth, axis=-1, kind='introselect', order=None) → ndarray
13+
```
14+
15+
**Note**: returns an array of the same shape as `a` and requires sorting the return value to get the top `k` values in order.
16+
17+
```
18+
numpy.argpartition(a, kth, axis=-1, kind='introselect', order=None) → ndarray
19+
```
20+
21+
**Note**: returns an array of the same shape as `a` and requires sorting after using the return value to get the top `k` values.
22+
23+
## CuPy
24+
25+
```
26+
cupy.argpartition(a, kth, axis=-1) → ndarray
27+
```
28+
29+
**Note**: performs a full sort.
30+
31+
## dask.array
32+
33+
```
34+
dask.array.topk(a, k, axis=-1, split_every=None) → ndarray
35+
```
36+
37+
**Note**: only returns values. If `k` is negative, returns the smallest `k` values. Returned values are sorted.
38+
39+
```
40+
dask.array.argtopk(a, k, axis=-1, split_every=None)
41+
```
42+
43+
**Note**: only returns indices. If `k` is negative, returns the indices for the smallest `k` values. Returned indices correspond to sorted values.
44+
45+
## JAX
46+
47+
```
48+
jax.lax.top_k(operand, k) → ndarray
49+
```
50+
51+
**Note**: only returns values.
52+
53+
```
54+
jax.numpy.partition(a, kth, axis=-1) → ndarray
55+
```
56+
57+
**Note**: implemented via two calls to `jax.lax.top_k`. Differs from NumPy in handling of NaN values, where NaN values which have negative sign bits are sorted to the beginning of the array.
58+
59+
```
60+
jax.numpy.argpartition(a, kth, axis=-1) → ndarray
61+
```
62+
63+
**Note**: implemented via two calls to `jax.lax.top_k`. Differs from NumPy in handling of NaN values, where NaN values which have negative sign bits are sorted to the beginning of the array.
64+
65+
## MXNet
66+
67+
```
68+
npx.topk(data, axis=-1, k=1, ret_typ='indices', is_ascend=False, dtype='float32') → ndarray | [ndarray, ndarray]
69+
```
70+
71+
**Note**: whether a single ndarray or a list of ndarrays is returned is determined by `ret_type`.
72+
73+
## PyTorch
74+
75+
```
76+
torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None) → (Tensor, LongTensor)
77+
```
78+
79+
**Note**: returns a named tuple containing values and indices. Differs from NumPy et al for default `dim`.
80+
81+
## TensorFlow
82+
83+
```
84+
tf.math.top_k(input, k=1, sorted=True, index_type=tf.dtypes.int32, name=None
85+
) → (Tensor, Tensor)
86+
```
87+
88+
**Note**: returns a `(values, indices)` tuple. Only supports last axis.

0 commit comments

Comments
 (0)
0