8000 factorize the reference class to compare to · scikit-learn/scikit-learn@f838fa9 · GitHub
[go: up one dir, main page]

Skip to content

Commit f838fa9

Browse files
author
Joan Massich
committed
factorize the reference class to compare to
1 parent 01e776e commit f838fa9

File tree

2 files changed

+16
-14
lines changed

2 files changed

+16
-14
lines changed

sklearn/model_selection/_split.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717
from itertools import chain, combinations
1818
from collections import Iterable
1919
from math import ceil, floor
20-
import numbers
2120
from abc import ABCMeta, abstractmethod
2221

2322
import numpy as np
2423

2524
from ..utils import indexable, check_random_state, safe_indexing
2625
from ..utils.validation import _num_samples, column_or_1d
2726
from ..utils.validation import check_array
27+
from ..utils.validation import integer_types, floating_types
2828
from ..utils.multiclass import type_of_target
2929
from ..externals.six import with_metaclass
3030
from ..externals.six.moves import zip
@@ -271,7 +271,7 @@ class _BaseKFold(with_metaclass(ABCMeta, BaseCrossValidator)):
271271

272272
@abstractmethod
273273
def __init__(self, n_splits, shuffle, random_state):
274-
if not isinstance(n_splits, (numbers.Integral, np.integer)):
274+
if not isinstance(n_splits, integer_types):
275275
raise ValueError('The number of folds must be of Integral type. '
276276
'%s of type %s was passed.'
277277
% (n_splits, type(n_splits)))
@@ -989,7 +989,7 @@ class _RepeatedSplits(with_metaclass(ABCMeta)):
989989
and shuffle.
990990
"""
991991
def __init__(self, cv, n_repeats=10, random_state=None, **cvargs):
992-
if not isinstance(n_repeats, (np.integer, numbers.Integral)):
992+
if not isinstance(n_repeats, integer_types):
993993
raise ValueError("Number of repetitions must be of Integral type.")
994994

995995
if n_repeats <= 0:
@@ -1643,27 +1643,27 @@ def _validate_shuffle_split_init(test_size, train_size):
16431643
raise ValueError('test_size and train_size can not both be None')
16441644

16451645
if test_size is not None:
1646-
if isinstance(test_size, (float, np.floating)):
1646+
if isinstance(test_size, floating_types):
16471647
if test_size >= 1.:
16481648
raise ValueError(
16491649
'test_size=%f should be smaller '
16501650
'than 1.0 or be an integer' % test_size)
1651-
elif not isinstance(test_size, (numbers.Integral, np.integer)):
1651+
elif not isinstance(test_size, integer_types):
16521652
# int values are checked during split based on the input
16531653
raise ValueError("Invalid value for test_size: %r" % test_size)
16541654

16551655
if train_size is not None:
1656-
if isinstance(train_size, (float, np.floating)):
1656+
if isinstance(train_size, floating_types):
16571657
if train_size >= 1.:
16581658
raise ValueError("train_size=%f should be smaller "
16591659
"than 1.0 or be an integer" % train_size)
1660-
elif (isinstance(test_size, (float, np.floating)) and
1660+
elif (isinstance(test_size, floating_types) and
16611661
(train_size + test_size) > 1.):
16621662
raise ValueError('The sum of test_size and train_size = %f, '
16631663
'should be smaller than 1.0. Reduce '
16641664
'test_size and/or train_size.' %
16651665
(train_size + test_size))
1666-
elif not isinstance(train_size, (numbers.Integral, np.integer)):
1666+
elif not isinstance(train_size, integer_types):
16671667
# int values are checked during split based on the input
16681668
raise ValueError("Invalid value for train_size: %r" % train_size)
16691669

@@ -1676,28 +1676,28 @@ def _validate_shuffle_split(n_samples, test_size, train_size):
16761676
test_size, defaults to 0.1
16771677
"""
16781678
if (test_size is not None and
1679-
isinstance(test_size, (numbers.Integral, np.integer)) and
1679+
isinstance(test_size, integer_types) and
16801680
test_size >= n_samples):
16811681
raise ValueError('test_size=%d should be smaller than the number of '
16821682
'samples %d' % (test_size, n_samples))
16831683

16841684
if (train_size is not None and
1685-
isinstance(train_size, (numbers.Integral, np.integer)) and
1685+
isinstance(train_size, integer_types) and
16861686
train_size >= n_samples):
16871687
raise ValueError("train_size=%d should be smaller than the number of"
16881688
" samples %d" % (train_size, n_samples))
16891689

16901690
if test_size == "default":
16911691
test_size = 0.1
16921692

1693-
if isinstance(test_size, (float, np.floating)):
1693+
if isinstance(test_size, floating_types):
16941694
n_test = ceil(test_size * n_samples)
1695-
elif isinstance(test_size, (numbers.Integral, np.integer)):
1695+
elif isinstance(test_size, integer_types):
16961696
n_test = float(test_size)
16971697

16981698
if train_size is None:
16991699
n_train = n_samples - n_test
1700-
elif isinstance(train_size, (float, np.floating)):
1700+
elif isinstance(train_size, floating_types):
17011701
n_train = floor(train_size * n_samples)
17021702
else:
17031703
n_train = float(train_size)
@@ -1902,7 +1902,7 @@ def check_cv(cv=3, y=None, classifier=False):
19021902
if cv is None:
19031903
cv = 3
19041904

1905-
if isinstance(cv, (numbers.Integral, np.integer)):
1905+
if isinstance(cv, integer_types):
19061906
if (classifier and (y is not None) and
19071907
(type_of_target(y) in ('binary', 'multiclass'))):
19081908
return StratifiedKFold(cv)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from ..externals.joblib import Memory
2525

2626

27+
integer_types = (numbers.Integral, np.integer)
28+
floating_types = (float, np.floating)
2729
FLOAT_DTYPES = (np.float64, np.float32, np.float16)
2830

2931
# Silenced by default to reduce verbosity. Turn on at runtime for

0 commit comments

Comments
 (0)
0