8000 BUG: wrong selection for orders falling into equal ranges · juliantaylor/numpy@763aeea · GitHub
[go: up one dir, main page]

Skip to content

Commit 763aeea

Browse files
committed
BUG: wrong selection for orders falling into equal ranges
when orders are selected where the kth element falls into an equal range the the last stored pivot was not the kth element, this leads to losing the ordering of smaller orders as following selection steps can start at index 0 again instead of the at the offset of the last selection. Closes numpygh-4836
1 parent e715bce commit 763aeea

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

numpy/core/src/npysort/selection.c.src

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,10 @@ int
379379
/* move pivot into position */
380380
SWAP(SORTEE(low), SORTEE(hh));
381381

382-
store_pivot(hh, kth, pivots, npiv);
382+
/* kth pivot stored later */
383+
if (hh != kth) {
384+
store_pivot(hh, kth, pivots, npiv);
385+
}
383386

384387
if (hh >= kth)
385388
high = hh - 1;
@@ -389,10 +392,11 @@ int
389392

390393
/* two elements */
391394
if (high == low + 1) {
392-
if (@TYPE@_LT(v[IDX(high)], v[IDX(low)]))
395+
if (@TYPE@_LT(v[IDX(high)], v[IDX(low)])) {
393396
SWAP(SORTEE(high), SORTEE(low))
394-
store_pivot(low, kth, pivots, npiv);
397+
}
395398
}
399+
store_pivot(kth, kth, pivots, npiv);
396400

397401
return 0;
398402
}

numpy/core/tests/test_multiarray.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,6 +1143,12 @@ def test_partition(self):
11431143
d[i:].partition(0, kind=k)
11441144
assert_array_equal(d, tgt)
11451145

1146+
d = np.array([0, 1, 2, 3, 4, 5, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
1147+
7, 7, 7, 7, 7, 9])
1148+
kth = [0, 3, 19, 20]
1149+
assert_equal(np.partition(d, kth, kind=k)[kth], (0, 3, 7, 7))
1150+
assert_equal(d[np.argpartition(d, kth, kind=k)][kth], (0, 3, 7, 7))
1151+
11461152
d = np.array([2, 1])
11471153
d.partition(0, kind=k)
11481154
assert_raises(ValueError, d.partition, 2)
@@ -1332,6 +1338,18 @@ def test_partition_cdtype(self):
13321338
assert_equal(np.partition(d, k)[k], tgt[k])
13331339
assert_equal(d[np.argpartition(d, k)][k], tgt[k])
13341340

1341+
def test_partition_fuzz(self):
1342+
# a few rounds of random data testing
1343+
for j in range(10, 30):
1344+
for i in range(1, j - 2):
1345+
d = np.arange(j)
1346+
np.random.shuffle(d)
1347+
d = d % np.random.randint(2, 30)
1348+
idx = np.random.randint(d.size)
1349+
kth = [0, idx, i, i + 1]
1350+
tgt = np.sort(d)[kth]
1351+
assert_array_equal(np.partition(d, kth)[kth], tgt,
1352+
err_msg="data: %r\n kth: %r" % (d, kth))
13351353

13361354
def test_flatten(self):
13371355
x0 = np.array([[1, 2, 3], [4, 5, 6]], np.int32)

0 commit comments

Comments
 (0)
0