10000 [MRG+1] TST Stronger test for _check_is_permutation (#7395) · TomDLT/scikit-learn@ea167b7 · GitHub
[go: up one dir, main page]

Skip to content

Commit ea167b7

Browse files
raghavrvTomDLT
authored andcommitted
[MRG+1] TST Stronger test for _check_is_permutation (scikit-learn#7395)
* TST Stronger test for _check_is_permutation * TST Ensure additional duplicate indices are caught
1 parent 0d599cc commit ea167b7

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

sklearn/model_selection/_validation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -476,11 +476,11 @@ def _check_is_permutation(indices, n_samples):
476476
Returns
477477
-------
478478
is_partition : bool
479-
True iff sorted(locs) is range(n)
479+
True iff sorted(indices) is np.arange(n)
480480
"""
481481
if len(indices) != n_samples:
482482
return False
483-
hit = np.zeros(n_samples, bool)
483+
hit = np.zeros(n_samples, dtype=bool)
484484
hit[indices] = True
485485
if not np.all(hit):
486486
return False

sklearn/model_selection/tests/test_validation.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,13 +731,18 @@ def test_validation_curve():
731731

732732

733733
def test_check_is_permutation():
734+
rng = np.random.RandomState(0)
734735
p = np.arange(100)
736+
rng.shuffle(p)
735737
assert_true(_check_is_permutation(p, 100))
736738
assert_false(_check_is_permutation(np.delete(p, 23), 100))
737739

738740
p[0] = 23
739741
assert_false(_check_is_permutation(p, 100))
740742

743+
# Check if the additional duplicate indices are caught
744+
assert_false(_check_is_permutation(np.hstack((p, 0)), 100))
745+
741746

742747
def test_cross_val_predict_sparse_prediction():
743748
# check that cross_val_predict gives same result for sparse and dense input

0 commit comments

Comments
 (0)
0