FIX all float/invalid type errors at init and int error at split · raghavrv/scikit-learn@789a316 · GitHub
[go: up one dir, main page]

Skip to content

Commit 789a316

Browse files
committed
FIX all float/invalid type errors at init and int error at split
1 parent e4d41ad commit 789a316

File tree

2 files changed

+36
-34
lines changed

2 files changed

+36
-34
lines changed

sklearn/model_selection/split.py

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -744,45 +744,43 @@ def _validate_shuffle_split_init(test_size, train_size):
744744
if test_size is None and train_size is None:
745745
raise ValueError('test_size and train_size can not both be None')
746746

747-
if (test_size is not None) and (np.asarray(test_size).dtype.kind == 'f'):
748-
if test_size >= 1.:
749-
raise ValueError(
750-
'test_size=%f should be smaller '
751-
'than 1.0 or be an integer' % test_size)
752-
753-
if (train_size is not None) and (np.asarray(train_size).dtype.kind == 'f'):
754-
if train_size >= 1.:
755-
raise ValueError("train_size=%f should be smaller "
756-
"than 1.0 or be an integer" % train_size)
757-
elif np.asarray(test_size).dtype.kind == 'f' and \
758-
train_size + test_size > 1.:
759-
raise ValueError('The sum of test_size and train_size = %f, '
760-
'should be smaller than 1.0. Reduce '
761-
'test_size and/or train_size.' %
762-
(train_size + test_size))
763-
764-
765-
def _validate_shuffle_split(n, test_size, train_size):
766747
if test_size is not None:
767-
if np.asarray(test_size).dtype.kind == 'i':
768-
if test_size >= n:
748+
if np.asarray(test_size).dtype.kind == 'f':
749+
if test_size >= 1.:
769750
raise ValueError(
770-
'test_size=%d should be smaller '
771-
'than the number of samples %d' % (test_size, n))
772-
elif np.asarray(test_size).dtype.kind != 'f':
773-
# Float values are checked during __init__
751+
'test_size=%f should be smaller '
752+
'than 1.0 or be an integer' % test_size)
753+
elif np.asarray(test_size).dtype.kind != 'i':
754+
# int values are checked during split based on the input
774755
raise ValueError("Invalid value for test_size: %r" % test_size)
775756

776757
if train_size is not None:
777-
if np.asarray(train_size).dtype.kind == 'i':
778-
if train_size >= n:
779-
raise ValueError("train_size=%d should be smaller "
780-
"than the number of samples %d" %
781-
(train_size, n))
782-
elif np.asarray(train_size).dtype.kind != 'f':
783-
# Float values are checked during __init__
758+
if np.asarray(train_size).dtype.kind == 'f':
759+
if train_size >= 1.:
760+
raise ValueError("train_size=%f should be smaller "
761+
"than 1.0 or be an integer" % train_size)
762+
elif ((np.asarray(test_size).dtype.kind == 'f') and
763+
((train_size + test_size) > 1.)):
764+
raise ValueError('The sum of test_size and train_size = %f, '
765+
'should be smaller than 1.0. Reduce '
766+
'test_size and/or train_size.' %
767+
(train_size + test_size))
768+
elif np.asarray(train_size).dtype.kind != 'i':
769+
# int values are checked during split based on the input
784770
raise ValueError("Invalid value for train_size: %r" % train_size)
785771

772+
773+
def _validate_shuffle_split(n, test_size, train_size):
774+
if ((test_size is not None) and (np.asarray(test_size).dtype.kind == 'i')
775+
and (test_size >= n)):
776+
raise ValueError('test_size=%d should be smaller '
777+
'than the number of samples %d' % (test_size, n))
778+
779+
if ((train_size is not None) and (np.asarray(train_size).dtype.kind == 'i')
780+
and (train_size >= n)):
781+
raise ValueError("train_size=%d should be smaller "
782+
"than the number of samples %d" % (train_size, n))
783+
786784
if np.asarray(test_size).dtype.kind == 'f':
787785
n_test = ceil(test_size * n)
788786
elif np.asarray(test_size).dtype.kind == 'i':

sklearn/model_selection/tests/test_split.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -636,15 +636,19 @@ def train_test_split_mock_pandas():
636636

637637

638638
def test_shufflesplit_errors():
639+
# When the {test|train}_size is a float/invalid error is raised at init
640+
assert_raises(ValueError, ShuffleSplit, test_size=None, train_size=None)
639641
assert_raises(ValueError, ShuffleSplit, test_size=2.0)
640642
assert_raises(ValueError, ShuffleSplit, test_size=1.0)
641643
assert_raises(ValueError, ShuffleSplit, test_size=0.1, train_size=0.95)
644+
assert_raises(ValueError, ShuffleSplit, train_size=1j)
645+
646+
# When the {test|train}_size is an int validation is based on the input X
647+
# and happens at split(...)
642648
assert_raises(ValueError, next, ShuffleSplit(test_size=11).split(X))
643649
assert_raises(ValueError, next, ShuffleSplit(test_size=10).split(X))
644650
assert_raises(ValueError, next, ShuffleSplit(test_size=8,
645651
train_size=3).split(X))
646-
assert_raises(ValueError, ShuffleSplit, test_size=None, train_size=None)
647-
assert_raises(ValueError, next, ShuffleSplit(train_size=1j).split(X))
648652

649653

650654
def test_shufflesplit_reproducible():

0 commit comments

Comments
 (0)
0