10000 Regression test for stable descending `array_api.argsort()` · numpy/numpy@0e81997 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0e81997

Browse files
committed
Regression test for stable descending array_api.argsort()
1 parent 007d347 commit 0e81997

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pytest
2+
3+
from numpy import array_api as xp
4+
5+
6+
@pytest.mark.parametrize(
7+
"obj, axis, expected",
8+
[
9+
([0, 0], -1, [0, 1]),
10+
([0, 1, 0], -1, [1, 0, 2]),
11+
([[0, 1], [1, 1]], 0, [[1, 0], [0, 1]]),
12+
([[0, 1], [1, 1]], 1, [[1, 0], [0, 1]]),
13+
],
14+
)
15+
def test_stable_desc_argsort(obj, axis, expected):
16+
"""
17+
Indices respect relative order of a descending stable-sort
18+
19+
See https://github.com/numpy/numpy/issues/20778
20+
"""
21+
x = xp.asarray(obj)
22+
out = xp.argsort(x, axis=axis, stable=True, descending=True)
23+
assert xp.all(out == xp.asarray(expected, dtype=out.dtype))

0 commit comments

Comments
 (0)
0