MAINT Refactor cv, gs and lc into the model_selection module. · raghavrv/scikit-learn@ed8728a · GitHub
[go: up one dir, main page]

Skip to content

Commit ed8728a

Browse files
committed
MAINT Refactor cv, gs and lc into the model_selection module.
Main Commits - Major -------------------- * ENH Reorganize classes/fn from grid_search into search.py * ENH Reorganize classes/fn from cross_validation into split.py * ENH Reorganize cls/fn from cross_validation/learning_curve into validate.py * MAINT Merge _check_cv into check_cv inside the model_selection module * MAINT Update all the imports to point to the model_selection module * FIX use iter_cv to iterate through the new style/old style cv objs * TST Add tests for the new model_selection members * ENH Wrap the old-style cv obj/iterables instead of using iter_cv * ENH Use scipy's binomial coefficient function comb for calculation of nCk * ENH Few enhancements to the split module * ENH Improve check_cv input validation and docstring * MAINT _get_test_folds(X, y, labels) --> _get_test_folds(labels) * TST if 1d arrays for X introduce any errors * ENH use 1d X arrays for all tests; * ENH X_10 --> X (global var) Minor ----- * ENH _PartitionIterator --> _BaseCrossValidator; * ENH CVIterator --> CVIterableWrapper * TST Import the old SKF locally * FIX/TST Clean up the split module's tests. * DOC Improve documentation of the cv parameter * COSMIT consistently hyphenate cross-validation/cross-validator * TST Calculate n_samples from X * COSMIT Use separate lines for each import. * COSMIT cross_validation_generator --> cross_validator Commits merged manually ----------------------- * FIX Document the random_state attribute in RandomizedSearchCV * MAINT Use check_cv instead of _check_cv * ENH refactor OVO decision function, use it in SVC for sklearn-like decision_function shape * FIX avoid memory cost when sampling from large parameter grids
1 parent 8d273a1 commit ed8728a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+5401
-87
lines changed

sklearn/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@
6262
'ensemble', 'externals', 'feature_extraction',
6363
'feature_selection', 'gaussian_process', 'grid_search',
6464
'isotonic', 'kernel_approximation', 'kernel_ridge',
65-
'lda', 'learning_curve',
66-
'linear_model', 'manifold', 'metrics', 'mixture', 'multiclass',
65+
'lda', 'learning_curve', 'linear_model', 'manifold', 'metrics',
66+
'mixture', 'model_selection', 'multiclass',
6767
'naive_bayes', 'neighbors', 'neural_network', 'pipeline',
6868
'preprocessing', 'qda', 'random_projection', 'semi_supervised',
6969
'svm', 'tree', 'discriminant_analysis',

sklearn/calibration.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from .utils.fixes import signature
2323
from .isotonic import IsotonicRegression
2424
from .svm import LinearSVC
25-
from .cross_validation import check_cv
25+
from .model_selection import check_cv
2626
from .metrics.classification import _check_binary_probabilistic_predictions
2727

2828

@@ -152,7 +152,7 @@ def fit(self, X, y, sample_weight=None):
152152
calibrated_classifier.fit(X, y)
153153
self.calibrated_classifiers_.append(calibrated_classifier)
154154
else:
155-
cv = check_cv(self.cv, X, y, classifier=True)
155+
cv = check_cv(self.cv, y, classifier=True)
156156
fit_parameters = signature(base_estimator.fit).parameters
157157
estimator_name = type(base_estimator).__name__
158158
if (sample_weight is not None
@@ -163,7 +163,7 @@ def fit(self, X, y, sample_weight=None):
163163
base_estimator_sample_weight = None
164164
else:
165165
base_estimator_sample_weight = sample_weight
166-
for train, test in cv:
166+
for train, test in cv.split(X, y):
167167
this_estimator = clone(base_estimator)
168168
if base_estimator_sample_weight is not None:
169169
this_estimator.fit(

sklearn/cluster/tests/test_bicluster.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
from scipy.sparse import csr_matrix, issparse
55

6-
from sklearn.grid_search import ParameterGrid
6+
from sklearn.model_selection import ParameterGrid
77

88
from sklearn.utils.testing import assert_equal
99
from sklearn.utils.testing import assert_almost_equal

sklearn/covariance/graph_lasso_.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ..utils.validation import check_random_state, check_array
2222
from ..linear_model import lars_path
2323
from ..linear_model import cd_fast
24-
from ..cross_validation import check_cv, cross_val_score
24+
from ..model_selection import check_cv, cross_val_score
2525
from ..externals.joblib import Parallel, delayed
2626
import collections
2727

@@ -580,7 +580,7 @@ def fit(self, X, y=None):
580580
emp_cov = empirical_covariance(
581581
X, assume_centered=self.assume_centered)
582582

583-
cv = check_cv(self.cv, X, y, classifier=False)
583+
cv = check_cv(self.cv, y, classifier=False)
584584

585585
# List of (alpha, scores, covs)
586586
path = list()
@@ -612,14 +612,13 @@ def fit(self, X, y=None):
612612
this_path = Parallel(
613613
n_jobs=self.n_jobs,
614614
verbose=self.verbose
615-
)(
616-
delayed(graph_lasso_path)(
617-
X[train], alphas=alphas,
618-
X_test=X[test], mode=self.mode,
619-
tol=self.tol, enet_tol=self.enet_tol,
620-
max_iter=int(.1 * self.max_iter),
621-
verbose=inner_verbose)
622-
for train, test in cv)
615+
)(delayed(graph_lasso_path)(X[train], alphas=alphas,
616+
X_test=X[test], mode=self.mode,
617+
tol=self.tol,
618+
enet_tol=self.enet_tol,
619+
max_iter=int(.1 * self.max_iter),
620+
verbose=inner_verbose)
621+
for train, test in cv.split(X, y))
623622

624623
# Little danse to transform the list in what we need
625624
covs, _, scores = zip(*this_path)

sklearn/cross_validation.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,14 @@
3232
from .metrics.scorer import check_scoring
3333
from .utils.fixes import bincount
3434

35+
36+
warnings.warn("This module has been deprecated in favor of the "
37+
"model_selection module into which all the refactored classes "
38+
"and functions are moved. Also note that the interface of the "
39+
"new CV iterators are different from that of this module. "
40+
"Refer to model_selection for more info.", DeprecationWarning)
41+
42+
3543
__all__ = ['KFold',
3644
'LabelKFold',
3745
'LeaveOneLabelOut',
@@ -302,7 +310,7 @@ class KFold(_BaseKFold):
302310
303311
See also
304312
--------
305-
StratifiedKFold: take label information into account to avoid building
313+
StratifiedKFold take label information into account to avoid building
306314
folds with imbalanced class distributions (for binary or multiclass
307315
classification tasks).
308316

sklearn/decomposition/tests/test_kernel_pca.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from sklearn.datasets import make_circles
1010
from sklearn.linear_model import Perceptron
1111
from sklearn.pipeline import Pipeline
12-
from sklearn.grid_search import GridSearchCV
12+
from sklearn.model_selection import GridSearchCV
1313
from sklearn.metrics.pairwise import rbf_kernel
1414

1515

sklearn/ensemble/tests/test_bagging.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@
2121
from sklearn.utils.testing import assert_warns_message
2222

2323
from sklearn.dummy import DummyClassifier, DummyRegressor
24-
from sklearn.grid_search import GridSearchCV, ParameterGrid
24+
from sklearn.model_selection import GridSearchCV, ParameterGrid
2525
from sklearn.ensemble import BaggingClassifier, BaggingRegressor
2626
from sklearn.linear_model import Perceptron, LogisticRegression
2727
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
2828
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
2929
from sklearn.svm import SVC, SVR
3030
from sklearn.pipeline import make_pipeline
3131
from sklearn.feature_selection import SelectKBest
32-
from sklearn.cross_validation import train_test_split
32+
from sklearn.model_selection import train_test_split
3333
from sklearn.datasets import load_boston, load_iris, make_hastie_10_2
3434
from sklearn.utils import check_random_state
3535

sklearn/ensemble/tests/test_forest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from sklearn.ensemble import RandomForestClassifier
3939
from sklearn.ensemble import RandomForestRegressor
4040
from sklearn.ensemble import RandomTreesEmbedding
41-
from sklearn.grid_search import GridSearchCV
41+
from sklearn.model_selection import GridSearchCV
4242
from sklearn.svm import LinearSVC
4343
from sklearn.utils.fixes import bincount
4444
from sklearn.utils.validation import check_random_state

sklearn/ensemble/tests/test_voting_classifier.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
from sklearn.naive_bayes import GaussianNB
88
from sklearn.ensemble import RandomForestClassifier
99
from sklearn.ensemble import VotingClassifier
10-
from sklearn.grid_search import GridSearchCV
10+
from sklearn.model_selection import GridSearchCV
1111
from sklearn import datasets
12-
from sklearn import cross_validation
12+
from sklearn.model_selection import cross_val_score
1313
from sklearn.datasets import make_multilabel_classification
1414
from sklearn.svm import SVC
1515
from sklearn.multiclass import OneVsRestClassifier
@@ -27,11 +27,7 @@ def test_majority_label_iris():
2727
eclf = VotingClassifier(estimators=[
2828
('lr', clf1), ('rf', clf2), ('gnb', clf3)],
2929
voting='hard')
30-
scores = cross_validation.cross_val_score(eclf,
31-
X,
32-
y,
33-
cv=5,
34-
scoring='accuracy')
30+
scores = cross_val_score(eclf, X, y, cv=5, scoring='accuracy')
3531
assert_almost_equal(scores.mean(), 0.95, decimal=2)
3632

3733

@@ -55,11 +51,7 @@ def test_weights_iris():
5551
('lr', clf1), ('rf', clf2), ('gnb', clf3)],
5652
voting='soft',
5753
weights=[1, 2, 10])
58-
scores = cross_validation.cross_val_score(eclf,
59-
X,
60-
y,
61-
cv=5,
62-
scoring='accuracy')
54+
scores = cross_val_score(eclf, X, y, cv=5, scoring='accuracy')
6355
assert_almost_equal(scores.mean(), 0.93, decimal=2)
6456

6557

sklearn/ensemble/tests/test_weight_boosting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from sklearn.utils.testing import assert_raises, assert_raises_regexp
88

99
from sklearn.base import BaseEstimator
10-
from sklearn.cross_validation import train_test_split
11-
from sklearn.grid_search import GridSearchCV
10+
from sklearn.model_selection import train_test_split
11+
from sklearn.model_selection import GridSearchCV
1212
from sklearn.ensemble import AdaBoostClassifier
1313
from sklearn.ensemble import AdaBoostRegressor
1414
from sklearn.ensemble import weight_boosting

sklearn/feature_extraction/tests/test_text.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212

1313
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS
1414

15-
from sklearn.cross_validation import train_test_split
16-
from sklearn.cross_validation import cross_val_score
17-
from sklearn.grid_search import GridSearchCV
15+
from sklearn.model_selection import train_test_split
16+
from sklearn.model_selection import cross_val_score
17+
from sklearn.model_selection import GridSearchCV
1818
from sklearn.pipeline import Pipeline
1919
from sklearn.svm import LinearSVC
2020

sklearn/feature_selection/rfe.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from ..base import MetaEstimatorMixin
1515
from ..base import clone
1616
from ..base import is_classifier
17-
from ..cross_validation import check_cv
18-
from ..cross_validation import _safe_split, _score
17+
from ..model_selection import check_cv
18+
from ..model_selection.validate import _safe_split, _score
1919
from ..metrics.scorer import check_scoring
2020
from .base import SelectorMixin
2121

@@ -402,7 +402,7 @@ def fit(self, X, y):
402402
"value is set via the estimator initialisation or "
403403
"set_params method.", DeprecationWarning)
404404
# Initialization
405-
cv = check_cv(self.cv, X, y, is_classifier(self.estimator))
405+
cv = check_cv(self.cv, y, is_classifier(self.estimator))
406406
scorer = check_scoring(self.estimator, scoring=self.scoring)
407407
n_features = X.shape[1]
408408
n_features_to_select = 1
@@ -411,7 +411,7 @@ def fit(self, X, y):
411411
scores = []
412412

413413
# Cross-validation
414-
for n, (train, test) in enumerate(cv):
414+
for n, (train, test) in enumerate(cv.split(X, y)):
415415
X_train, y_train = _safe_split(self.estimator, X, y, train)
416416
X_test, y_test = _safe_split(self.estimator, X, y, test, train)
417417

@@ -447,7 +447,7 @@ def fit(self, X, y):
447447
self.estimator_.set_params(**self.estimator_params)
448448
self.estimator_.fit(self.transform(X), y)
449449

450-
# Fixing a normalization error, n is equal to len(cv) - 1
451-
# here, the scores are normalized by len(cv)
452-
self.grid_scores_ = scores / len(cv)
450+
# Fixing a normalization error, n is equal to len_cv - 1
451+
# here, the scores are normalized by len_cv
452+
self.grid_scores_ = scores / cv.n_splits(X, y)
453453
return self

sklearn/feature_selection/tests/test_rfe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sklearn.metrics import zero_one_loss
1313
from sklearn.svm import SVC, SVR
1414
from sklearn.ensemble import RandomForestClassifier
15-
from sklearn.cross_validation import cross_val_score
15+
from sklearn.model_selection import cross_val_score
1616

1717
from sklearn.utils import check_random_state
1818
from sklearn.utils.testing import ignore_warnings

sklearn/grid_search.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@
3636
'ParameterSampler', 'RandomizedSearchCV']
3737

3838

39+
warnings.warn("This module has been deprecated in favor of the "
40+
"model_selection module into which all the refactored classes "
41+
"and functions are moved.", DeprecationWarning)
42+
43+
3944
class ParameterGrid(object):
4045
"""Grid of parameters with a discrete number of values for each.
4146

sklearn/learning_curve.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@
1717
from .utils.fixes import astype
1818

1919

20+
warnings.warn("This module has been deprecated in favor of the "
21+
"model_selection module into which all the functions are moved.",
22+
DeprecationWarning)
23+
24+
2025
__all__ = ['learning_curve', 'validation_curve']
2126

2227

sklearn/linear_model/coordinate_descent.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from .base import center_data, sparse_center_data
1818
from ..utils import check_array, check_X_y, deprecated
1919
from ..utils.validation import check_random_state
20-
from ..cross_validation import check_cv
20+
from ..model_selection import check_cv
2121
from ..externals.joblib import Parallel, delayed
2222
from ..externals import six
2323
from ..externals.six.moves import xrange
@@ -1129,10 +1129,10 @@ def fit(self, X, y):
11291129
path_params['copy_X'] = False
11301130

11311131
# init cross-validation generator
1132-
cv = check_cv(self.cv, X)
1132+
cv = check_cv(self.cv)
11331133

11341134
# Compute path for all folds and compute MSE to get the best alpha
1135-
folds = list(cv)
1135+
folds = list(cv.split(X))
11361136
best_mse = np.inf
11371137

11381138
# We do a double for loop folded in one, in order to be able to
@@ -1370,6 +1370,7 @@ class ElasticNetCV(LinearModelCV, RegressorMixin):
13701370
dual gap for optimality and continues until it is smaller
13711371
than ``tol``.
13721372
1373+
<<<<<<< HEAD
13731374
cv : int, cross-validation generator or an iterable, optional
13741375
Determines the cross-validation splitting strategy.
13751376
Possible inputs for cv are:
@@ -1382,6 +1383,13 @@ class ElasticNetCV(LinearModelCV, RegressorMixin):
13821383
13831384
Refer :ref:`User Guide <cross_validation>` for the various
13841385
cross-validation strategies that can be used here.
1386+
=======
1387+
cv : integer or cross-validation generator, optional
1388+
If an integer is passed, it is the number of fold (default 3).
1389+
Specific cross-validation objects can be passed, see the
1390+
:mod:`sklearn.model_selection.split` module for the list of
1391+
possible objects.
1392+
>>>>>>> ENH introduce the model_selection module
13851393
13861394
verbose : bool or integer
13871395
Amount of verbosity.
@@ -1852,6 +1860,7 @@ class MultiTaskElasticNetCV(LinearModelCV, RegressorMixin):
18521860
dual gap for optimality and continues until it is smaller
18531861
than ``tol``.
18541862
1863+
<<<<<<< HEAD
18551864
cv : int, cross-validation generator or an iterable, optional
18561865
Determines the cross-validation splitting strategy.
18571866
Possible inputs for cv are:
@@ -1864,6 +1873,13 @@ class MultiTaskElasticNetCV(LinearModelCV, RegressorMixin):
18641873
18651874
Refer :ref:`User Guide <cross_validation>` for the various
18661875
cross-validation strategies that can be used here.
1876+
=======
1877+
cv : integer or cross-validation generator, optional
1878+
If an integer is passed, it is the number of fold (default 3).
1879+
Specific cross-validation objects can be passed, see the
1880+
:mod:`sklearn.model_selection.split` module for the list of
1881+
possible objects.
1882+
>>>>>>> ENH introduce the model_selection module
18671883
18681884
verbose : bool or integer
18691885
Amount of verbosity.
@@ -2009,6 +2025,7 @@ class MultiTaskLassoCV(LinearModelCV, RegressorMixin):
20092025
dual gap for optimality and continues until it is smaller
20102026
than ``tol``.
20112027
2028+
<<<<<<< HEAD
20122029
cv : int, cross-validation generator or an iterable, optional
20132030
Determines the cross-validation splitting strategy.
20142031
Possible inputs for cv are:
@@ -2021,6 +2038,13 @@ class MultiTaskLassoCV(LinearModelCV, RegressorMixin):
20212038
20222039
Refer :ref:`User Guide <cross_validation>` for the various
20232040
cross-validation strategies that can be used here.
2041+
=======
2042+
cv : integer or cross-validation generator, optional
2043+
If an integer is passed, it is the number of fold (default 3).
2044+
Specific cross-validation objects can be passed, see the
2045+
:mod:`sklearn.model_selection.split` module for the list of
2046+
possible objects.
2047+
>>>>>>> ENH introduce the model_selection module
20242048
20252049
verbose : bool or integer
20262050
Amount of verbosity.

sklearn/linear_model/least_angle.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from .base import LinearModel
2323
from ..base import RegressorMixin
2424
from ..utils import arrayfuncs, as_float_array, check_X_y
25-
from ..cross_validation import check_cv
25+
from ..model_selection import check_cv
2626
from ..utils import ConvergenceWarning
2727
from ..externals.joblib import Parallel, delayed
2828
from ..externals.six.moves import xrange
@@ -1076,7 +1076,7 @@ def fit(self, X, y):
10761076
X, y = check_X_y(X, y, y_numeric=True)
10771077

10781078
# init cross-validation generator
1079-
cv = check_cv(self.cv, X, y, classifier=False)
1079+
cv = check_cv(self.cv, classifier=False)
10801080

10811081
Gram = 'auto' if self.precompute else None
10821082

0 commit comments

Comments
 (0)
0