|
41 | 41 | from datetime import timedelta, datetime
|
42 | 42 |
|
43 | 43 |
|
| 44 | +def assert_arg_sorted(arr, arg): |
| 45 | + # resulting array should be sorted and arg values should be unique |
| 46 | + assert_equal(arr[arg], np.sort(arr)) |
| 47 | + assert_equal(np.sort(arg), np.arange(len(arg))) |
| 48 | + |
| 49 | + |
44 | 50 | def _aligned_zeros(shape, dtype=float, order="C", align=None):
|
45 | 51 | """
|
46 | 52 | Allocate a new ndarray with aligned memory.
|
@@ -9989,3 +9995,39 @@ def test_sort_uint():
|
9989 | 9995 |
|
9990 | 9996 | def test_private_get_ndarray_c_version():
|
9991 | 9997 | assert isinstance(_get_ndarray_c_version(), int)
|
| 9998 | + |
| 9999 | + |
| 10000 | +@pytest.mark.parametrize("N", np.arange(1, 512)) |
| 10001 | +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) |
| 10002 | +def test_argsort_float(N, dtype): |
| 10003 | + rnd = np.random.RandomState(116112) |
| 10004 | + # (1) Regular data with a few nan: doesn't use vectorized sort |
| 10005 | + arr = -0.5 + rnd.random(N).astype(dtype) |
| 10006 | + arr[rnd.choice(arr.shape[0], 3)] = np.nan |
| 10007 | + assert_arg_sorted(arr, np.argsort(arr, kind='quick')) |
| 10008 | + |
| 10009 | + # (2) Random data with inf at the end of array |
| 10010 | + # See: https://github.com/intel/x86-simd-sort/pull/39 |
| 10011 | + arr = -0.5 + rnd.rand(N).astype(dtype) |
| 10012 | + arr[N-1] = np.inf |
| 10013 | + assert_arg_sorted(arr, np.argsort(arr, kind='quick')) |
| 10014 | + |
| 10015 | + |
| 10016 | +@pytest.mark.parametrize("N", np.arange(2, 512)) |
| 10017 | +@pytest.mark.parametrize("dtype", [np.int32, np.uint32, np.int64, np.uint64]) |
| 10018 | +def test_argsort_int(N, dtype): |
| 10019 | + rnd = np.random.RandomState(1100710816) |
| 10020 | + # (1) random data with min and max values |
| 10021 | + minv = np.iinfo(dtype).min |
| 10022 | + maxv = np.iinfo(dtype).max |
| 10023 | + arr = rnd.randint(low=minv, high=maxv, size=N, dtype=dtype) |
| 10024 | + i, j = rnd.choice(N, 2, replace=False) |
| 10025 | + arr[i] = minv |
| 10026 | + arr[j] = maxv |
| 10027 | + assert_arg_sorted(arr, np.argsort(arr, kind='quick')) |
| 10028 | + |
| 10029 | + # (2) random data with max value at the end of array |
| 10030 | + # See: https://github.com/intel/x86-simd-sort/pull/39 |
| 10031 | + arr = rnd.randint(low=minv, high=maxv, size=N, dtype=dtype) |
| 10032 | + arr[N-1] = maxv |
| 10033 | + assert_arg_sorted(arr, np.argsort(arr, kind='quick')) |
0 commit comments