8000 BUG: wrong selection for orders falling into equal ranges · juliantaylor/numpy@d6c7a16 · GitHub < 10000 meta name="ui-target" content="full">
[go: up one dir, main page]

Skip to content

Commit d6c7a16

Browse files
committed
BUG: wrong selection for orders falling into equal ranges
when orders are selected where the kth element falls into an equal range the the last stored pivot was not the kth element, this leads to losing the ordering of smaller orders as following selection steps can start at index 0 again instead of the at the offset of the last selection. Closes numpygh-4836
1 parent e8d1374 commit d6c7a16

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

numpy/core/src/npysort/selection.c.src

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,10 @@ int
390390
/* move pivot into position */
391391
SWAP(SORTEE(low), SORTEE(hh));
392392

393-
store_pivot(hh, kth, pivots, npiv);
393+
/* kth pivot stored later */
394+
if (hh != kth) {
395+
store_pivot(hh, kth, pivots, npiv);
396+
}
394397

395398
if (hh >= kth)
396399
high = hh - 1;
@@ -400,10 +403,11 @@ int
400403

401404
/* two elements */
402405
if (high == low + 1) {
403-
if (@TYPE@_LT(v[IDX(high)], v[IDX(low)]))
406+
if (@TYPE@_LT(v[IDX(high)], v[IDX(low)])) {
404407
SWAP(SORTEE(high), SORTEE(low))
405-
store_pivot(low, kth, pivots, npiv);
408+
}
406409
}
410+
store_pivot(kth, kth, pivots, npiv);
407411

408412
return 0;
409413
}

numpy/core/tests/test_multiarray.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,6 +1356,12 @@ def test_partition(self):
13561356
d[i:].partition(0, kind=k)
13571357
assert_array_equal(d, tgt)
13581358

1359+
d = np.array([0, 1, 2, 3, 4, 5, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
1360+
7, 7, 7, 7, 7, 9])
1361+
kth = [0, 3, 19, 20]
1362+
assert_equal(np.partition(d, kth, kind=k)[kth], (0, 3, 7, 7))
1363+
assert_equal(d[np.argpartition(d, kth, kind=k)][kth], (0, 3, 7, 7))
1364+
13591365
d = np.array([2, 1])
13601366
d.partition(0, kind=k)
13611367
assert_raises(ValueError, d.partition, 2)
@@ -1551,6 +1557,18 @@ def test_partition_unicode_kind(self):
15511557
assert_raises(ValueError, d.partition, 2, kind=k)
15521558
assert_raises(ValueError, d.argpartition, 2, kind=k)
15531559

1560+
def test_partition_fuzz(self):
1561+
# a few rounds of random data testing
1562+
for j in range(10, 30):
1563+
for i in range(1, j - 2):
1564+
d = np.arange(j)
1565+
np.random.shuffle(d)
1566+
d = d % np.random.randint(2, 30)
1567+
idx = np.random.randint(d.size)
1568+
kth = [0, idx, i, i + 1]
1569+
tgt = np.sort(d)[kth]
1570+
assert_array_equal(np.partition(d, kth)[kth], tgt,
1571+
err_msg="data: %r\n kth: %r" % (d, kth))
15541572

15551573
def test_flatten(self):
15561574
x0 = np.array([[1, 2, 3], [4, 5, 6]], np.int32)

0 commit comments

Comments
 (0)
0