8000 TST: Add tests for np.argsort (#23846) · numpy/numpy@35d23ba · GitHub
[go: up one dir, main page]

Skip to content

Commit 35d23ba

Browse files
r-devulaprkern
authored andcommitted
TST: Add tests for np.argsort (#23846)
Contributing a few tests I had used when developing AVX-512 based argsort. Co-authored-by: Robert Kern <robert.kern@gmail.com>
1 parent eaffdf3 commit 35d23ba

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

numpy/core/tests/test_multiarray.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@
4141
from datetime import timedelta, datetime
4242

4343

44+
def assert_arg_sorted(arr, arg):
45+
# resulting array should be sorted and arg values should be unique
46+
assert_equal(arr[arg], np.sort(arr))
47+
assert_equal(np.sort(arg), np.arange(len(arg)))
48+
49+
4450
def _aligned_zeros(shape, dtype=float, order="C", align=None):
4551
"""
4652
Allocate a new ndarray with aligned memory.
@@ -9989,3 +9995,39 @@ def test_sort_uint():
99899995

99909996
def test_private_get_ndarray_c_version():
99919997
assert isinstance(_get_ndarray_c_version(), int)
9998+
9999+
10000+
@pytest.mark.parametrize("N", np.arange(1, 512))
10001+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
10002+
def test_argsort_float(N, dtype):
10003+
rnd = np.random.RandomState(116112)
10004+
# (1) Regular data with a few nan: doesn't use vectorized sort
10005+
arr = -0.5 + rnd.random(N).astype(dtype)
10006+
arr[rnd.choice(arr.shape[0], 3)] = np.nan
10007+
assert_arg_sorted(arr, np.argsort(arr, kind='quick'))
10008+
10009+
# (2) Random data with inf at the end of array
10010+
# See: https://github.com/intel/x86-simd-sort/pull/39
10011+
arr = -0.5 + rnd.rand(N).astype(dtype)
10012+
arr[N-1] = np.inf
10013+
assert_arg_sorted(arr, np.argsort(arr, kind='quick'))
10014+
10015+
10016+
@pytest.mark.parametrize("N", np.arange(2, 512))
10017+
@pytest.mark.parametrize("dtype", [np.int32, np.uint32, np.int64, np.uint64])
10018+
def test_argsort_int(N, dtype):
10019+
rnd = np.random.RandomState(1100710816)
10020+
# (1) random data with min and max values
10021+
minv = np.iinfo(dtype).min
10022+
maxv = np.iinfo(dtype).max
10023+
arr = rnd.randint(low=minv, high=maxv, size=N, dtype=dtype)
10024+
i, j = rnd.choice(N, 2, replace=False)
10025+
arr[i] = minv
10026+
arr[j] = maxv
10027+
assert_arg_sorted(arr, np.argsort(arr, kind='quick'))
10028+
10029+
# (2) random data with max value at the end of array
10030+
# See: https://github.com/intel/x86-simd-sort/pull/39
10031+
arr = rnd.randint(low=minv, high=maxv, size=N, dtype=dtype)
10032+
arr[N-1] = maxv
10033+
assert_arg_sorted(arr, np.argsort(arr, kind='quick'))

0 commit comments

Comments
 (0)
0