8000 MNT Use isinstance(..., float/numbers.Integral) · scikit-learn/scikit-learn@01e776e · GitHub
[go: up one dir, main page]

Skip to content

Commit 01e776e

Browse files
raghavrvJoan Massich
authored and
Joan Massich
committed
MNT Use isinstance(..., float/numbers.Integral)
1 parent 8e599c6 commit 01e776e

File tree

2 files changed

+20
-13
lines changed

2 files changed

+20
-13
lines changed

sklearn/model_selection/_split.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ class _BaseKFold(with_metaclass(ABCMeta, BaseCrossValidator)):
271271

272272
@abstractmethod
273273
def __init__(self, n_splits, shuffle, random_state):
274-
if not isinstance(n_splits, numbers.Integral):
274+
if not isinstance(n_splits, (numbers.Integral, np.integer)):
275275
raise ValueError('The number of folds must be of Integral type. '
276276
'%s of type %s was passed.'
277277
% (n_splits, type(n_splits)))
@@ -1643,27 +1643,27 @@ def _validate_shuffle_split_init(test_size, train_size):
16431643
raise ValueError('test_size and train_size can not both be None')
16441644

16451645
if test_size is not None:
1646-
if np.asarray(test_size).dtype.kind == 'f':
1646+
if isinstance(test_size, (float, np.floating)):
16471647
if test_size >= 1.:
16481648
raise ValueError(
16491649
'test_size=%f should be smaller '
16501650
'than 1.0 or be an integer' % test_size)
1651-
elif np.asarray(test_size).dtype.kind != 'i':
1651+
elif not isinstance(test_size, (numbers.Integral, np.integer)):
16521652
# int values are checked during split based on the input
16531653
raise ValueError("Invalid value for test_size: %r" % test_size)
16541654

16551655
if train_size is not None:
1656-
if np.asarray(train_size).dtype.kind == 'f':
1656+
if isinstance(train_size, (float, np.floating)):
16571657
if train_size >= 1.:
16581658
raise ValueError("train_size=%f should be smaller "
16591659
"than 1.0 or be an integer" % train_size)
1660-
elif (np.asarray(test_size).dtype.kind == 'f' and
1660+
elif (isinstance(test_size, (float, np.floating)) and
16611661
(train_size + test_size) > 1.):
16621662
raise ValueError('The sum of test_size and train_size = %f, '
16631663
'should be smaller than 1.0. Reduce '
16641664
'test_size and/or train_size.' %
16651665
(train_size + test_size))
1666-
elif np.asarray(train_size).dtype.kind != 'i':
1666+
elif not isinstance(train_size, (numbers.Integral, np.integer)):
16671667
# int values are checked during split based on the input
16681668
raise ValueError("Invalid value for train_size: %r" % train_size)
16691669

@@ -1672,30 +1672,32 @@ def _validate_shuffle_split(n_samples, test_size, train_size):
16721672
"""
16731673
Validation helper to check if the test/test sizes are meaningful wrt to the
16741674
size of the data (n_samples)
1675+
1676+
test_size, defaults 8000 to 0.1
16751677
"""
16761678
if (test_size is not None and
1677-
np.asarray(test_size).dtype.kind == 'i' and
1679+
isinstance(test_size, (numbers.Integral, np.integer)) and
16781680
test_size >= n_samples):
16791681
raise ValueError('test_size=%d should be smaller than the number of '
16801682
'samples %d' % (test_size, n_samples))
16811683

16821684
if (train_size is not None and
1683-
np.asarray(train_size).dtype.kind == 'i' and
1685+
isinstance(train_size, (numbers.Integral, np.integer)) and
16841686
train_size >= n_samples):
16851687
raise ValueError("train_size=%d should be smaller than the number of"
16861688
" samples %d" % (train_size, n_samples))
16871689

16881690
if test_size == "default":
16891691
test_size = 0.1
16901692

1691-
if np.asarray(test_size).dtype.kind == 'f':
1693+
if isinstance(test_size, (float, np.floating)):
16921694
n_test = ceil(test_size * n_samples)
1693-
elif np.asarray(test_size).dtype.kind == 'i':
1695+
elif isinstance(test_size, (numbers.Integral, np.integer)):
16941696
n_test = float(test_size)
16951697

16961698
if train_size is None:
16971699
n_train = n_samples - n_test
1698-
elif np.asarray(train_size).dtype.kind == 'f':
1700+
elif isinstance(train_size, (float, np.floating)):
16991701
n_train = floor(train_size * n_samples)
17001702
else:
17011703
n_train = float(train_size)
@@ -1900,7 +1902,7 @@ def check_cv(cv=3, y=None, classifier=False):
19001902
if cv is None:
19011903
cv = 3
19021904

1903-
if isinstance(cv, numbers.Integral):
1905+
if isinstance(cv, (numbers.Integral, np.integer)):
19041906
if (classifier and (y is not None) and
19051907
(type_of_target(y) in ('binary', 'multiclass'))):
19061908
return StratifiedKFold(cv)

sklearn/model_selection/tests/test_split.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -543,15 +543,20 @@ def test_kfold_can_detect_dependent_samples_on_digits(): # see #2372
543543

544544

545545
def test_shuffle_split():
546+
# Use numpy float as input
547+
ss0 = ShuffleSplit(test_size=np.float16(0.2), random_state=0).split(X)
546548
ss1 = ShuffleSplit(test_size=0.2, random_state=0).split(X)
547549
ss2 = ShuffleSplit(test_size=2, random_state=0).split(X)
550+
# Use numpy int as input
548551
ss3 = ShuffleSplit(test_size=np.int32(2), random_state=0).split(X)
549552
for typ in six.integer_types:
550553
ss4 = ShuffleSplit(test_size=typ(2), random_state=0).split(X)
551-
for t1, t2, t3, t4 in zip(ss1, ss2, ss3, ss4):
554+
for t0, t1, t2, t3, t4 in zip(ss0, ss1, ss2, ss3, ss4):
555+
assert_array_equal(t0[0], t1[0])
552556
assert_array_equal(t1[0], t2[0])
553557
assert_array_equal(t2[0], t3[0])
554558
assert_array_equal(t3[0], t4[0])
559+
assert_array_equal(t0[1], t1[1])
555560
assert_array_equal(t1[1], t2[1])
556561
assert_array_equal(t2[1], t3[1])
557562
assert_array_equal(t3[1], t4[1])

0 commit comments

Comments
 (0)
0