8000 Move the *CVS constants to _validation as a dict · raghavrv/scikit-learn@74ec175 · GitHub
[go: up one dir, main page]

Skip to content

Commit 74ec175

Browse files
committed
Move the *CVS constants to _validation as a dict
1 parent ca9517b commit 74ec175

File tree

2 files changed

+29
-8
lines changed

2 files changed

+29
-8
lines changed

sklearn/model_selection/__init__.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,3 @@
4949
'permutation_test_score',
5050
'train_test_split',
5151
'validation_curve')
52-
53-
54-
ALL_CVS = (KFold, LabelKFold, LeaveOneLabelOut, LeaveOneOut, LeavePLabelOut,
55-
LeavePOut, ShuffleSplit, LabelShuffleSplit, StratifiedKFold,
56-
StratifiedShuffleSplit, PredefinedSplit)
57-
58-
LABEL_CVS = (LabelKFold, LeaveOneLabelOut, LeavePLabelOut, LabelShuffleSplit,)

sklearn/model_selection/_validation.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,41 @@
2525
from ..utils.validation import _is_arraylike, _num_samples
2626
from ..externals.joblib import Parallel, delayed, logger
2727
from ..metrics.scorer import check_scoring
28-
from ._split import check_cv, _safe_split
2928
from ..exceptions import FitFailedWarning
3029

30+
from ._split import KFold
31+
from ._split import LabelKFold
32+
from ._split import LeaveOneLabelOut
33+
from ._split import LeaveOneOut
34+
from ._split import LeavePLabelOut
35+
from ._split import LeavePOut
36+
from ._split import ShuffleSplit
37+
from ._split import LabelShuffleSplit
38+
from ._split import StratifiedKFold
39+
from ._split import StratifiedShuffleSplit
40+
from ._split import PredefinedSplit
41+
from ._split import check_cv, _safe_split
3142

3243
__all__ = ['cross_val_score', 'cross_val_predict', 'permutation_test_score',
3344
'learning_curve', 'validation_curve']
3445

46+
ALL_CVS = {'KFold': KFold,
47+
'LabelKFold': LabelKFold,
48+
'LeaveOneLabelOut': LeaveOneLabelOut,
49+
'LeaveOneOut': LeaveOneOut,
50+
'LeavePLabelOut': LeavePLabelOut,
51+
'LeavePOut': LeavePOut,
52+
'ShuffleSplit': ShuffleSplit,
53+
'LabelShuffleSplit': LabelShuffleSplit,
54+
'StratifiedKFold': StratifiedKFold,
55+
'StratifiedShuffleSplit': StratifiedShuffleSplit,
56+
'PredefinedSplit': PredefinedSplit}
57+
58+
LABEL_CVS = {'LabelKFold': LabelKFold,
59+
'LeaveOneLabelOut': LeaveOneLabelOut,
60+
'LeavePLabelOut': LeavePLabelOut,
61+
'LabelShuffleSplit': LabelShuffleSplit}
62+
3563

3664
def cross_val_score(estimator, X, y=None, labels=None, scoring=None, cv=None,
3765
n_jobs=1, verbose=0, fit_params=None,

0 commit comments

Comments
 (0)
0