8000 ENH Various enhancements to the model_selection module · scikit-learn/scikit-learn@8f7bea9 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8f7bea9

Browse files
committed
ENH Various enhancements to the model_selection module
1 parent c0bc2f8 commit 8f7bea9

File tree

3 files changed

+31
-40
lines changed

3 files changed

+31
-40
lines changed

sklearn/model_selection/_split.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1599,3 +1599,21 @@ def _build_repr(self):
15991599
params[key] = value
16001600

16011601
return '%s(%s)' % (class_name, _pprint(params, offset=len(class_name)))
1602+
1603+
1604+
ALL_CVS = {'KFold': KFold,
1605+
'LabelKFold': LabelKFold,
1606+
'LeaveOneLabelOut': LeaveOneLabelOut,
1607+
'LeaveOneOut': LeaveOneOut,
1608+
'LeavePLabelOut': LeavePLabelOut,
1609+
'LeavePOut': LeavePOut,
1610+
'ShuffleSplit': ShuffleSplit,
1611+
'LabelShuffleSplit': LabelShuffleSplit,
1612+
'StratifiedKFold': StratifiedKFold,
1613+
'StratifiedShuffleSplit': StratifiedShuffleSplit,
1614+
'PredefinedSplit': PredefinedSplit}
1615+
1616+
LABEL_CVS = {'LabelKFold': LabelKFold,
1617+
'LeaveOneLabelOut': LeaveOneLabelOut,
1618+
'LeavePLabelOut': LeavePLabelOut,
1619+
'LabelShuffleSplit': LabelShuffleSplit}

sklearn/model_selection/_validation.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -27,39 +27,9 @@
2727
from ..metrics.scorer import check_scoring
2828
from ..exceptions import FitFailedWarning
2929

30-
from ._split import KFold
31-
from ._split import LabelKFold
32-
from ._split import LeaveOneLabelOut
33-
from ._split import LeaveOneOut
34-
from ._split import LeavePLabelOut
35-
from ._split import LeavePOut
36-
from ._split import ShuffleSplit
37-
from ._split import LabelShuffleSplit
38-
from ._split import StratifiedKFold
39-
from ._split import StratifiedShuffleSplit
40-
from ._split import PredefinedSplit
41-
from ._split import check_cv, _safe_split
42-
4330
__all__ = ['cross_val_score', 'cross_val_predict', 'permutation_test_score',
4431
'learning_curve', 'validation_curve']
4532

46-
ALL_CVS = {'KFold': KFold,
47-
'LabelKFold': LabelKFold,
48-
'LeaveOneLabelOut': LeaveOneLabelOut,
49-
'LeaveOneOut': LeaveOneOut,
50-
'LeavePLabelOut': LeavePLabelOut,
51-
'LeavePOut': LeavePOut,
52-
'ShuffleSplit': ShuffleSplit,
53-
'LabelShuffleSplit': LabelShuffleSplit,
54-
'StratifiedKFold': StratifiedKFold,
55-
'StratifiedShuffleSplit': StratifiedShuffleSplit,
56-
'PredefinedSplit': PredefinedSplit}
57-
58-
LABEL_CVS = {'LabelKFold': LabelKFold,
59-
'LeaveOneLabelOut': LeaveOneLabelOut,
60-
'LeavePLabelOut': LeavePLabelOut,
61-
'LabelShuffleSplit': LabelShuffleSplit}
62-
6333

6434
def cross_val_score(estimator, X, y=None, labels=None, scoring=None, cv=None,
6535
n_jobs=1, verbose=0, fit_params=None,

sklearn/model_selection/tests/test_search.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,6 @@
3434

3535
from sklearn.model_selection import KFold
3636
from sklearn.model_selection import StratifiedKFold
37-
from sklearn.model_selection import StratifiedShuffleSplit
38-
from sklearn.model_selection import LeaveOneLabelOut
39-
from sklearn.model_selection import LeavePLabelOut
40-
from sklearn.model_selection import LabelKFold
41-
from sklearn.model_selection import LabelShuffleSplit
4237
from sklearn.model_selection import GridSearchCV
4338
from sklearn.model_selection import RandomizedSearchCV
4439
from sklearn.model_selection import ParameterGrid
@@ -47,6 +42,7 @@
4742
# TODO Import from sklearn.exceptions once merged.
4843
from sklearn.base import ChangedBehaviorWarning
4944
from sklearn.model_selection._validation import FitFailedWarning
45+
from sklearn.model_selection._split import ALL_CVS, LABEL_CVS
5046

5147
from sklearn.svm import LinearSVC, SVC
5248
from sklearn.tree import DecisionTreeRegressor
@@ -60,6 +56,14 @@
6056
from sklearn.pipeline import Pipeline
6157

6258

59+
def initialize_cross_validators(CVClass):
60+
# set parameters to initialize the cross-validators
61+
if CVClass is ALL_CVS['LeavePLabelOut']:
62+
return CVClass(n_labels=2)
63+
if CVClass is ALL_CVS['LeavePOut']:
64+
return CVClass(p=2)
65+
66+
6367
# Neither of the following two estimators inherit from BaseEstimator,
6468
# to test hyperparameter search on user-defined classifiers.
6569
class MockClassifier(object):
@@ -235,17 +239,16 @@ def test_grid_search_labels():
235239
clf = LinearSVC(random_state=0)
236240
grid = {'C': [1]}
237241

238-
label_cvs = [LeaveOneLabelOut(), LeavePLabelOut(2), LabelKFold(),
239-
LabelShuffleSplit()]
240-
for cv in label_cvs:
242+
for _, CVClass in LABEL_CVS.iteritems():
243+
cv = initialize_cross_validators(CVClass)
241244
gs = GridSearchCV(clf, grid, cv=cv)
242245
assert_raise_message(ValueError,
243246
"The labels parameter should not be None",
244247
gs.fit, X, y)
245248
gs.fit(X, y, labels)
246249

247-
non_label_cvs = [StratifiedKFold(), StratifiedShuffleSplit()]
248-
for cv in non_label_cvs:
250+
for _, CVClass in (set(ALL_CVS.iteritems()) - set(LABEL_CVS.iteritems())):
251+
cv = initialize_cross_validators(CVClass)
249252
gs = GridSearchCV(clf, grid, cv=cv)
250253
# Should not raise an error
251254
gs.fit(X, y)

0 commit comments

Comments
 (0)
0