8000 factorize the reference class to compare to · scikit-learn/scikit-learn@f838fa9 · GitHub
[go: up one dir, main page]

Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit f838fa9

Browse files
author
Joan Massich
committed
factorize the reference class to compare to
1 parent 01e776e commit f838fa9

File tree

2 files changed

+16
-14
lines changed

2 files changed

+16
-14
lines changed

sklearn/model_selection/_split.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717
from itertools import chain, combinations
1818
from collections import Iterable
1919
from math import ceil, floor
20-
import numbers
2120
from abc import ABCMeta, abstractmethod
2221

2322
import numpy as np
2423

2524
from ..utils import indexable, check_random_state, safe_indexing
2625
from ..utils.validation import _num_samples, column_or_1d
2726
from ..utils.validation import check_array
27+
from ..utils.validation import integer_types, floating_types
2828
from ..utils.multiclass import type_of_target
2929
from ..externals.six import with_metaclass
3030
from ..externals.six.moves import zip
@@ -271,7 +271,7 @@ class _BaseKFold(with_metaclass(ABCMeta, BaseCrossValidator)):
271271

272272
@abstractmethod
273273
def __init__(self, n_splits, shuffle, random_state):
274-
if not isinstance(n_splits, (numbers.Integral, np.integer)):
274+
if not isinstance(n_splits, integer_types):
275275
raise ValueError('The number of folds must be of Integral type. '
276276
'%s of type %s was passed.'
277277
% (n_splits, type(n_splits)))
@@ -989,7 +989,7 @@ class _RepeatedSplits(with_metaclass(ABCMeta)):
989989
and shuffle.
990990
"""
991991
def __init__(self, cv, n_repeats=10, random_state=None, **cvargs):
992-
if not isinstance(n_repeats, (np.integer, numbers.Integral)):
992+
if not isinstance(n_repeats, integer_types):
993993
raise ValueError("Number of repetitions must be of Integral type.")
994994

995995
if n_repeats <= 0:
@@ -1643,27 +1643,27 @@ def _validate_shuffle_split_init(test_size, train_size):
16431643
raise ValueError('test_size and train_size can not both be None')
16441644

16451645
if test_size is not None:
1646-
if isinstance(test_size, (float, np.floating)):
1646+
if isinstance(test_size, floating_types):
16471647
if test_size >= 1.:
16481648
raise ValueError(
16491649
'test_size=%f should be smaller '
16501650
'than 1.0 or be an integer' % test_size)
1651-
elif not isinstance(test_size, (numbers.Integral, np.integer)):
1651+
elif not isinstance(test_size, integer_types):
16521652
# int values are checked during split based on the input
16531653
raise ValueError("Invalid value for test_size: %r" % test_size)
16541654

16551655
if train_size is not None:
1656-
if isinstance(train_size, (float, np.floating)):
1656+
if isinstance(train_size, floating_types):
16571657
if train_size >= 1.:
16581658
raise ValueError("train_size=%f should be smaller "
16591659
"than 1.0 or be an integer" % train_size)
1660-
elif (isinstance(test_size, (float, np.floating)) and
1660+
elif (isinstance(test_size, floating_types) and
16611661
(train_size + test_size) > 1.):
16621662
raise ValueError('The sum of test_size and train_size = %f, '
16631663
'should be smaller than 1.0. Reduce '
16641664
'test_size and/or train_size.' %
16651665
(train_size + test_size))
1666-
elif not isinstance(train_size, (numbers.Integral, np.integer)):
1666+
elif not isinstance(train_size, integer_types):
16671667
# int values are checked during split based on the input
16681668
raise ValueError("Invalid value for train_size: %r" % train_size)
16691669

@@ -1676,28 +1676,28 @@ def _validate_shuffle_split(n_samples, test_size, train_size):
16761676
test_size, defaults to 0.1
16771677
"""
16781678
if (test_size is not None and
1679-
isinstance(test_size, (numbers.Integral, np.integer)) and
1679+
isinstance(test_size, integer_types) and
16801680
test_size >= n_samples):
16811681
raise ValueError('test_size=%d should be smaller than the number of '
16821682
'samples %d' % (test_size, n_samples))
16831683

16841684
if (train_size is not None and
1685-
isinstance(train_size, (numbers.Integral, np.integer)) and
1685+
isinstance(train_size, integer_types) and
16861686
train_size >= n_samples):
16871687
raise ValueError("train_size=%d should be smaller than the number of"
16881688
" samples %d" % (train_size, n_samples))
16891689

16901690
if test_size == "default":
16911691
test_size = 0.1
16921692

1693-
if isinstance(test_size, (float, np.floating)):
1693+
if isinstance(test_size, floating_types):
16941694
n_test = ceil(test_size * n_samples)
1695-
elif isinstance(test_size, (numbers.Integral, np.integer)):
1695+
elif isinstance(test_size, integer_types):
16961696
n_test = float(test_size)
16971697

16981698
if train_size is None:
16991699
n_train = n_samples - n_test
1700-
elif isinstance(train_size, (float, np.floating)):
1700+
elif isinstance(train_size, floating_types):
17011701
n_train = floor(train_size * n_samples)
17021702
else:
17031703
n_train = float(train_size)
@@ -1902,7 +1902,7 @@ def check_cv(cv=3, y=None, classifier=False):
19021902
if cv is None:
19031903
cv = 3
19041904

1905-
if isinstance(cv, (numbers.Integral, np.integer)):
1905+
if isinstance(cv, integer_types):
19061906
if (classifier and (y is not None) and
19071907
(type_of_target(y) in ('binary', 'multiclass'))):
19081908
return StratifiedKFold(cv)

sklearn/utils/validation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from ..externals.joblib import Memory
2525

2626

27+
integer_types = (numbers.Integral, np.integer)
28+
floating_types = (float, np.floating)
2729
FLOAT_DTYPES = (np.float64, np.float32, np.float16)
2830

2931
# Silenced by default to reduce verbosity. Turn on at runtime for

0 commit comments

Comments
 (0)
0