Merge pull request #4294 from rvraghav93/model_selection · scikit-learn/scikit-learn@646e47c · GitHub

Commit 646e47c

Merge pull request #4294 from rvraghav93/model_selection
[MRG+1] Reorganize grid_search, cross_validation and learning_curve into model_selection
2 parents 91753dc + bdd94e9 commit 646e47c

Some content is hidden (large commits hide some content by default).

49 files changed: +6175 -103 lines

doc/whats_new.rst (+10)

@@ -25,6 +25,16 @@ New features
 Enhancements
 ............

+- The cross-validation iterators are now modified as cross-validation splitters
+  which expose a ``split`` method that takes in the data and yields a generator
+  for the different splits. This change makes it possible to do nested cross-validation
+  with ease. (`#4294 <https://github.com/scikit-learn/scikit-learn/pull/4294>`_) by `Raghav R V`_.
+
+- The :mod:`cross_validation`, :mod:`grid_search` and :mod:`learning_curve`
+  have been deprecated and the classes and functions have been reorganized into
+  the :mod:`model_selection` module. (`#4294 <https://github.com/scikit-learn/scikit-learn/pull/4294>`_) by `Raghav R V`_.
+
+
 Bug fixes
 .........
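
To illustrate the two changelog entries above, here is a minimal sketch of the new splitter interface (toy data; the estimator choice and parameter grid are invented for illustration, and splitter defaults are left implicit to stay version-agnostic):

import numpy as np
from sklearn.model_selection import GridSearchCV, KFold, cross_val_score
from sklearn.svm import SVC

X = np.random.RandomState(0).rand(30, 4)
y = np.arange(30) % 2

# New-style splitters hold only their configuration; the data is handed to
# split(), which yields (train, test) index arrays.
outer_cv = KFold()
inner_cv = KFold()
for train, test in outer_cv.split(X, y):
    print(train.shape, test.shape)

# Nested cross-validation: the inner splitter tunes C, the outer one scores
# the tuned model.
search = GridSearchCV(SVC(), param_grid={'C': [0.1, 1.0]}, cv=inner_cv)
nested_scores = cross_val_score(search, X, y, cv=outer_cv)
print(nested_scores.mean())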

sklearn/__init__.py (+2 -2)

@@ -62,8 +62,8 @@
     'ensemble', 'exceptions', 'externals', 'feature_extraction',
     'feature_selection', 'gaussian_process', 'grid_search',
     'isotonic', 'kernel_approximation', 'kernel_ridge',
-    'lda', 'learning_curve',
-    'linear_model', 'manifold', 'metrics', 'mixture', 'multiclass',
+    'lda', 'learning_curve', 'linear_model', 'manifold', 'metrics',
+    'mixture', 'model_selection', 'multiclass',
     'naive_bayes', 'neighbors', 'neural_network', 'pipeline',
     'preprocessing', 'qda', 'random_projection', 'semi_supervised',
     'svm', 'tree', 'discriminant_analysis',

sklearn/calibration.py (+3 -3)

@@ -22,7 +22,7 @@
 from .utils.fixes import signature
 from .isotonic import IsotonicRegression
 from .svm import LinearSVC
-from .cross_validation import check_cv
+from .model_selection import check_cv
 from .metrics.classification import _check_binary_probabilistic_predictions


@@ -152,7 +152,7 @@ def fit(self, X, y, sample_weight=None):
                 calibrated_classifier.fit(X, y)
             self.calibrated_classifiers_.append(calibrated_classifier)
         else:
-            cv = check_cv(self.cv, X, y, classifier=True)
+            cv = check_cv(self.cv, y, classifier=True)
             fit_parameters = signature(base_estimator.fit).parameters
             estimator_name = type(base_estimator).__name__
             if (sample_weight is not None
@@ -163,7 +163,7 @@ def fit(self, X, y, sample_weight=None):
                 base_estimator_sample_weight = None
            else:
                 base_estimator_sample_weight = sample_weight
-            for train, test in cv:
+            for train, test in cv.split(X, y):
                 this_estimator = clone(base_estimator)
                 if base_estimator_sample_weight is not None:
                     this_estimator.fit(
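
The two changes above show the new check_cv contract: it now takes (cv, y, classifier) rather than (cv, X, y, classifier), and the returned splitter is iterated via cv.split(X, y) instead of directly. A hedged sketch with toy data (arrays invented for illustration):

import numpy as np
from sklearn.model_selection import check_cv

X = np.random.RandomState(0).rand(12, 3)
y = np.array([0, 1] * 6)

# For classifiers with a class label vector, check_cv resolves an integer cv
# to a stratified splitter.
cv = check_cv(3, y, classifier=True)
for train, test in cv.split(X, y):
    print(len(train), len(test))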

sklearn/cluster/tests/test_bicluster.py (+1 -1)

@@ -3,7 +3,7 @@
 import numpy as np
 from scipy.sparse import csr_matrix, issparse

-from sklearn.grid_search import ParameterGrid
+from sklearn.model_selection import ParameterGrid

 from sklearn.utils.testing import assert_equal
 from sklearn.utils.testing import assert_almost_equal

sklearn/covariance/graph_lasso_.py (+9 -10)

@@ -21,7 +21,7 @@
 from ..utils.validation import check_random_state, check_array
 from ..linear_model import lars_path
 from ..linear_model import cd_fast
-from ..cross_validation import check_cv, cross_val_score
+from ..model_selection import check_cv, cross_val_score
 from ..externals.joblib import Parallel, delayed
 import collections

@@ -580,7 +580,7 @@ def fit(self, X, y=None):
         emp_cov = empirical_covariance(
             X, assume_centered=self.assume_centered)

-        cv = check_cv(self.cv, X, y, classifier=False)
+        cv = check_cv(self.cv, y, classifier=False)

         # List of (alpha, scores, covs)
         path = list()
@@ -612,14 +612,13 @@ def fit(self, X, y=None):
             this_path = Parallel(
                 n_jobs=self.n_jobs,
                 verbose=self.verbose
-            )(
-                delayed(graph_lasso_path)(
-                    X[train], alphas=alphas,
-                    X_test=X[test], mode=self.mode,
-                    tol=self.tol, enet_tol=self.enet_tol,
-                    max_iter=int(.1 * self.max_iter),
-                    verbose=inner_verbose)
-                for train, test in cv)
+            )(delayed(graph_lasso_path)(X[train], alphas=alphas,
+                                        X_test=X[test], mode=self.mode,
+                                        tol=self.tol,
+                                        enet_tol=self.enet_tol,
+                                        max_iter=int(.1 * self.max_iter),
+                                        verbose=inner_verbose)
+              for train, test in cv.split(X, y))

             # Little danse to transform the list in what we need
             covs, _, scores = zip(*this_path)
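
The refactored Parallel call above dispatches one graph_lasso_path fit per (train, test) pair produced by cv.split(X, y). A hedged sketch of the same dispatch pattern, with a trivial stand-in for the per-fold work (fold_shapes is a hypothetical helper, not part of scikit-learn):

import numpy as np
from sklearn.externals.joblib import Parallel, delayed
from sklearn.model_selection import KFold


def fold_shapes(X_train, X_test):
    # Placeholder for the real per-fold computation (graph_lasso_path above).
    return X_train.shape[0], X_test.shape[0]


X = np.random.RandomState(0).rand(12, 3)
cv = KFold()
results = Parallel(n_jobs=1)(
    delayed(fold_shapes)(X[train], X[test])
    for train, test in cv.split(X))
print(results)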

sklearn/cross_validation.py (+9 -1)

@@ -34,6 +34,14 @@
 from .gaussian_process.kernels import Kernel as GPKernel
 from .exceptions import FitFailedWarning

+
+warnings.warn("This module has been deprecated in favor of the "
+              "model_selection module into which all the refactored classes "
+              "and functions are moved. Also note that the interface of the "
+              "new CV iterators are different from that of this module. "
+              "This module will be removed in 0.19.", DeprecationWarning)
+
+
 __all__ = ['KFold',
            'LabelKFold',
            'LeaveOneLabelOut',
@@ -304,7 +312,7 @@ class KFold(_BaseKFold):

     See also
     --------
-    StratifiedKFold: take label information into account to avoid building
+    StratifiedKFold take label information into account to avoid building
     folds with imbalanced class distributions (for binary or multiclass
     classification tasks).

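
The warning added above fires the first time the legacy module is imported in a process. A hedged sketch of how the message can be surfaced (assuming sklearn.cross_validation has not been imported yet in the session; a cached module import emits nothing):

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    import sklearn.cross_validation  # noqa: F401  (triggers the warning on first import)

deprecations = [w for w in caught if issubclass(w.category, DeprecationWarning)]
print([str(w.message) for w in deprecations])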

sklearn/decomposition/tests/test_kernel_pca.py (+1 -1)

@@ -9,7 +9,7 @@
 from sklearn.datasets import make_circles
 from sklearn.linear_model import Perceptron
 from sklearn.pipeline import Pipeline
-from sklearn.grid_search import GridSearchCV
+from sklearn.model_selection import GridSearchCV
 from sklearn.metrics.pairwise import rbf_kernel



sklearn/ensemble/tests/test_bagging.py (+2 -2)

@@ -21,15 +21,15 @@
 from sklearn.utils.testing import assert_warns_message

 from sklearn.dummy import DummyClassifier, DummyRegressor
-from sklearn.grid_search import GridSearchCV, ParameterGrid
+from sklearn.model_selection import GridSearchCV, ParameterGrid
 from sklearn.ensemble import BaggingClassifier, BaggingRegressor
 from sklearn.linear_model import Perceptron, LogisticRegression
 from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
 from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
 from sklearn.svm import SVC, SVR
 from sklearn.pipeline import make_pipeline
 from sklearn.feature_selection import SelectKBest
-from sklearn.cross_validation import train_test_split
+from sklearn.model_selection import train_test_split
 from sklearn.datasets import load_boston, load_iris, make_hastie_10_2
 from sklearn.utils import check_random_state


sklearn/ensemble/tests/test_forest.py (+1 -1)

@@ -38,7 +38,7 @@
 from sklearn.ensemble import RandomForestClassifier
 from sklearn.ensemble import RandomForestRegressor
 from sklearn.ensemble import RandomTreesEmbedding
-from sklearn.grid_search import GridSearchCV
+from sklearn.model_selection import GridSearchCV
 from sklearn.svm import LinearSVC
 from sklearn.utils.fixes import bincount
 from sklearn.utils.validation import check_random_state

sklearn/ensemble/tests/test_voting_classifier.py (+4 -12)

@@ -7,9 +7,9 @@
 from sklearn.naive_bayes import GaussianNB
 from sklearn.ensemble import RandomForestClassifier
 from sklearn.ensemble import VotingClassifier
-from sklearn.grid_search import GridSearchCV
+from sklearn.model_selection import GridSearchCV
 from sklearn import datasets
-from sklearn import cross_validation
+from sklearn.model_selection import cross_val_score
 from sklearn.datasets import make_multilabel_classification
 from sklearn.svm import SVC
 from sklearn.multiclass import OneVsRestClassifier
@@ -27,11 +27,7 @@ def test_majority_label_iris():
     eclf = VotingClassifier(estimators=[
         ('lr', clf1), ('rf', clf2), ('gnb', clf3)],
         voting='hard')
-    scores = cross_validation.cross_val_score(eclf,
-                                              X,
-                                              y,
-                                              cv=5,
-                                              scoring='accuracy')
+    scores = cross_val_score(eclf, X, y, cv=5, scoring='accuracy')
     assert_almost_equal(scores.mean(), 0.95, decimal=2)


@@ -55,11 +51,7 @@ def test_weights_iris():
         ('lr', clf1), ('rf', clf2), ('gnb', clf3)],
         voting='soft',
         weights=[1, 2, 10])
-    scores = cross_validation.cross_val_score(eclf,
-                                              X,
-                                              y,
-                                              cv=5,
-                                              scoring='accuracy')
+    scores = cross_val_score(eclf, X, y, cv=5, scoring='accuracy')
     assert_almost_equal(scores.mean(), 0.93, decimal=2)


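
The simplified tests above use cross_val_score imported from its new home; the call signature itself is unchanged. A hedged sketch with a plain estimator (the estimator choice is illustrative, not what the test uses):

from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

iris = datasets.load_iris()
scores = cross_val_score(LogisticRegression(), iris.data, iris.target,
                         cv=5, scoring='accuracy')
print(scores.mean())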

sklearn/ensemble/tests/test_weight_boosting.py (+2 -2)

@@ -7,8 +7,8 @@
 from sklearn.utils.testing import assert_raises, assert_raises_regexp

 from sklearn.base import BaseEstimator
-from sklearn.cross_validation import train_test_split
-from sklearn.grid_search import GridSearchCV
+from sklearn.model_selection import train_test_split
+from sklearn.model_selection import GridSearchCV
 from sklearn.ensemble import AdaBoostClassifier
 from sklearn.ensemble import AdaBoostRegressor
 from sklearn.ensemble import weight_boosting

sklearn/exceptions.py (+1 -1)

@@ -85,7 +85,7 @@ class FitFailedWarning(RuntimeWarning):

     Examples
     --------
-    >>> from sklearn.grid_search import GridSearchCV
+    >>> from sklearn.model_selection import GridSearchCV
     >>> from sklearn.svm import LinearSVC
     >>> from sklearn.exceptions import FitFailedWarning
     >>> import warnings

sklearn/feature_extraction/tests/test_text.py (+3 -3)

@@ -12,9 +12,9 @@

 from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS

-from sklearn.cross_validation import train_test_split
-from sklearn.cross_validation import cross_val_score
-from sklearn.grid_search import GridSearchCV
+from sklearn.model_selection import train_test_split
+from sklearn.model_selection import cross_val_score
+from sklearn.model_selection import GridSearchCV
 from sklearn.pipeline import Pipeline
 from sklearn.svm import LinearSVC


sklearn/feature_selection/rfe.py (+7 -7)

@@ -14,8 +14,8 @@
 from ..base import MetaEstimatorMixin
 from ..base import clone
 from ..base import is_classifier
-from ..cross_validation import check_cv
-from ..cross_validation import _safe_split, _score
+from ..model_selection import check_cv
+from ..model_selection._validation import _safe_split, _score
 from ..metrics.scorer import check_scoring
 from .base import SelectorMixin

@@ -373,7 +373,7 @@ def fit(self, X, y):
         X, y = check_X_y(X, y, "csr")

         # Initialization
-        cv = check_cv(self.cv, X, y, is_classifier(self.estimator))
+        cv = check_cv(self.cv, y, is_classifier(self.estimator))
         scorer = check_scoring(self.estimator, scoring=self.scoring)
         n_features = X.shape[1]
         n_features_to_select = 1
@@ -382,7 +382,7 @@ def fit(self, X, y):
         scores = []

         # Cross-validation
-        for n, (train, test) in enumerate(cv):
+        for n, (train, test) in enumerate(cv.split(X, y)):
             X_train, y_train = _safe_split(self.estimator, X, y, train)
             X_test, y_test = _safe_split(self.estimator, X, y, test, train)

@@ -414,7 +414,7 @@ def fit(self, X, y):
         self.estimator_ = clone(self.estimator)
         self.estimator_.fit(self.transform(X), y)

-        # Fixing a normalization error, n is equal to len(cv) - 1
-        # here, the scores are normalized by len(cv)
-        self.grid_scores_ = scores / len(cv)
+        # Fixing a normalization error, n is equal to get_n_splits(X, y) - 1
+        # here, the scores are normalized by get_n_splits(X, y)
+        self.grid_scores_ = scores / cv.get_n_splits(X, y)
         return self
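
The RFECV change above replaces len(cv) with the splitter's get_n_splits method, which is how the new splitters report their fold count (and which may need the data to answer in the general case). A hedged sketch on toy arrays (data invented for illustration):

import numpy as np
from sklearn.model_selection import StratifiedKFold

X = np.random.RandomState(0).rand(12, 4)
y = np.array([0, 1] * 6)

cv = StratifiedKFold()
# Number of folds the splitter will produce; the new splitters expose this
# through get_n_splits rather than len().
print(cv.get_n_splits(X, y))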

sklearn/feature_selection/tests/test_rfe.py (+1 -1)

@@ -12,7 +12,7 @@
 from sklearn.metrics import zero_one_loss
 from sklearn.svm import SVC, SVR
 from sklearn.ensemble import RandomForestClassifier
-from sklearn.cross_validation import cross_val_score
+from sklearn.model_selection import cross_val_score

 from sklearn.utils import check_random_state
 from sklearn.utils.testing import ignore_warnings

sklearn/grid_search.py (+6)

@@ -37,6 +37,12 @@
            'ParameterSampler', 'RandomizedSearchCV']


+warnings.warn("This module has been deprecated in favor of the "
+              "model_selection module into which all the refactored classes "
+              "and functions are moved. This module will be removed in 0.19.",
+              DeprecationWarning)
+
+
 class ParameterGrid(object):
     """Grid of parameters with a discrete number of values for each.


sklearn/learning_curve.py (+6)

@@ -17,6 +17,12 @@
 from .utils.fixes import astype


+warnings.warn("This module has been deprecated in favor of the "
+              "model_selection module into which all the functions are moved."
+              " This module will be removed in 0.19",
+              DeprecationWarning)
+
+
 __all__ = ['learning_curve', 'validation_curve']



sklearn/linear_model/coordinate_descent.py (+3 -3)

@@ -17,7 +17,7 @@
 from .base import center_data, sparse_center_data
 from ..utils import check_array, check_X_y, deprecated
 from ..utils.validation import check_random_state
-from ..cross_validation import check_cv
+from ..model_selection import check_cv
 from ..externals.joblib import Parallel, delayed
 from ..externals import six
 from ..externals.six.moves import xrange
@@ -1120,10 +1120,10 @@ def fit(self, X, y):
         path_params['copy_X'] = False

         # init cross-validation generator
-        cv = check_cv(self.cv, X)
+        cv = check_cv(self.cv)

         # Compute path for all folds and compute MSE to get the best alpha
-        folds = list(cv)
+        folds = list(cv.split(X))
         best_mse = np.inf

         # We do a double for loop folded in one, in order to be able to
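
For regressors such as the coordinate-descent CV estimators above, check_cv is now called without X or y, and the folds are materialized up front with list(cv.split(X)). A hedged sketch (toy data, illustrative only):

import numpy as np
from sklearn.model_selection import check_cv

X = np.random.RandomState(0).rand(10, 2)

cv = check_cv(3)           # no y and classifier=False -> a plain K-fold splitter
folds = list(cv.split(X))  # list of (train, test) index-array pairs
print(len(folds))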

sklearn/linear_model/least_angle.py (+3 -3)

@@ -22,7 +22,7 @@
 from .base import LinearModel
 from ..base import RegressorMixin
 from ..utils import arrayfuncs, as_float_array, check_X_y
-from ..cross_validation import check_cv
+from ..model_selection import check_cv
 from ..exceptions import ConvergenceWarning
 from ..externals.joblib import Parallel, delayed
 from ..externals.six.moves import xrange
@@ -1079,7 +1079,7 @@ def fit(self, X, y):
         y = as_float_array(y, copy=self.copy_X)

         # init cross-validation generator
-        cv = check_cv(self.cv, X, y, classifier=False)
+        cv = check_cv(self.cv, classifier=False)

         Gram = 'auto' if self.precompute else None

@@ -1089,7 +1089,7 @@ def fit(self, X, y):
                 method=self.method, verbose=max(0, self.verbose - 1),
                 normalize=self.normalize, fit_intercept=self.fit_intercept,
                 max_iter=self.max_iter, eps=self.eps, positive=self.positive)
-            for train, test in cv)
+            for train, test in cv.split(X, y))
         all_alphas = np.concatenate(list(zip(*cv_paths))[0])
         # Unique also sorts
         all_alphas = np.unique(all_alphas)

sklearn/linear_model/logistic.py (+5 -4)

@@ -1,3 +1,4 @@
+
 """
 Logistic Regression
 """
@@ -32,7 +33,7 @@
 from ..utils.fixes import expit
 from ..utils.multiclass import check_classification_targets
 from ..externals.joblib import Parallel, delayed
-from ..cross_validation import check_cv
+from ..model_selection import check_cv
 from ..externals import six
 from ..metrics import SCORERS

@@ -1309,7 +1310,7 @@ class LogisticRegressionCV(LogisticRegression, BaseEstimator,
     cv : integer or cross-validation generator
         The default cross-validation generator used is Stratified K-Folds.
         If an integer is provided, then it is the number of folds used.
-        See the module :mod:`sklearn.cross_validation` module for the
+        See the module :mod:`sklearn.model_selection` module for the
         list of possible cross-validation objects.

     penalty : str, 'l1' or 'l2'
@@ -1506,8 +1507,8 @@ def fit(self, X, y, sample_weight=None):
         check_consistent_length(X, y)

         # init cross-validation generator
-        cv = check_cv(self.cv, X, y, classifier=True)
-        folds = list(cv)
+        cv = check_cv(self.cv, y, classifier=True)
+        folds = list(cv.split(X, y))

         self._enc = LabelEncoder()
         self._enc.fit(y)

Comments (0)