8000 Main Commits - Major · scikit-learn/scikit-learn@3f8743f · GitHub
[go: up one dir, main page]

Skip to content

Commit 3f8743f

Browse files
raghavrvamueller
authored andcommitted
Main Commits - Major
-------------------- * ENH Reogranize classes/fn from grid_search into search.py * ENH Reogranize classes/fn from cross_validation into split.py * ENH Reogranize cls/fn from cross_validation/learning_curve into validate.py * MAINT Merge _check_cv into check_cv inside the model_selection module * MAINT Update all the imports to point to the model_selection module * FIX use iter_cv to iterate throught the new style/old style cv objs * TST Add tests for the new model_selection members * ENH Wrap the old-style cv obj/iterables instead of using iter_cv * ENH Use scipy's binomial coefficient function comb for calucation of nCk * ENH Few enhancements to the split module * ENH Improve check_cv input validation and docstring * MAINT _get_test_folds(X, y, labels) --> _get_test_folds(labels) * TST if 1d arrays for X introduce any errors * ENH use 1d X arrays for all tests; * ENH X_10 --> X (global var) Minor ----- * ENH _PartitionIterator --> _BaseCrossValidator; * ENH CVIterator --> CVIterableWrapper * TST Import the old SKF locally * FIX/TST Clean up the split module's tests. * DOC Improve documentation of the cv parameter * COSMIT consistently hyphenate cross-validation/cross-validator * TST Calculate n_samples from X * COSMIT Use separate lines for each import. * COSMIT cross_validation_generator --> cross_validator Commits merged manually ----------------------- * FIX Document the random_state attribute in RandomSearchCV * MAINT Use check_cv instead of _check_cv * ENH refactor OVO decision function, use it in SVC for sklearn-like decision_function shape * FIX avoid memory cost when sampling from large parameter grids ENH Major to Minor incremental enhancements to the model_selection Squashed commit messages - (For reference) Major ----- * ENH p --> n_labels * FIX *ShuffleSplit: all float/invalid type errors at init and int error at split * FIX make PredefinedSplit accept test_folds in constructor; Cleanup docstrings * ENH+TST KFold: make rng to be generated at every split call for reproducibility * FIX/MAINT KFold: make shuffle a public attr * FIX Make CVIterableWrapper private. * FIX reuse len_cv instead of recalculating it * FIX Prevent adding *SearchCV estimators from the old grid_search module * re-FIX In all_estimators: the sorting to use only the 1st item (name) To avoid collision between the old and the new GridSearch classes. * FIX test_validate.py: Use 2D X (1D X is being detected as a single sample) * MAINT validate.py --> validation.py * MAINT make the submodules private * MAINT Support old cv/gs/lc until 0.19 * FIX/MAINT n_splits --> get_n_splits * FIX/TST test_logistic.py/test_ovr_multinomial_iris: pass predefined folds as an iterable * MAINT expose BaseCrossValidator * Update the model_selection module with changes from master - From #5161 - - MAINT remove redundant p variable - - Add check for sparse prediction in cross_val_predict - From #5201 - DOC improve random_state param doc - From #5190 - LabelKFold and test - From #4583 - LabelShuffleSplit and tests - From #5300 - shuffle the `labels` not the `indxs` in LabelKFold + tests - From #5378 - Make the GridSearchCV docs more accurate. - From #5458 - Remove shuffle from LabelKFold - From #5466(#4270) - Gaussian Process by Jan Metzen - From #4826 - Move custom error / warnings into sklearn.exception Minor ----- * ENH Make the KFold shuffling test stronger * FIX/DOC Use the higher level model_selection module as ref * DOC in check_cv "y : array-like, optional" * DOC a supervised learning problem --> supervised learning problems * DOC cross-validators --> cross-validation strategies * DOC Correct Olivier Grisel's name ;) * MINOR/FIX cv_indices --> kfold * FIX/DOC Align the 'See also' section of the new KFold, LeaveOneOut * TST/FIX imports on separate lines * FIX use __class__ instead of classmethod * TST/FIX import directly from model_selection * COSMIT Relocate the random_state documentation * COSMIT remove pass * MAINT Remove deprecation warnings from old tests * FIX correct import at test_split * FIX/MAINT Move P_sparse, X, y defns to top; rm unused W_sparse, X_sparse * FIX random state to avoid doctes 8000 t failure * TST n_splits and split wrapping of _CVIterableWrapper * FIX/MAINT Use multilabel indicator matrix directly * TST/DOC clarify why we conflate classes 0 and 1 * DOC add comment that this was taken from BaseEstimator * FIX use of labels is not needed in stratified k fold * Fix cross_validation reference * Fix the labels param doc FIX/DOC/MAINT Addressing the review comments by Arnaud and Andy COSMIT Sort the members alphabetically COSMIT len_cv --> n_splits COSMIT Merge 2 if; FIX Use kwargs DOC Add my name to the authors :D DOC make labels parameter consistent FIX Remove hack for boolean indices; + COSMIT idx --> indices; DOC Add Returns COSMIT preds --> predictions DOC Add Returns and neatly arrange X, y, labels FIX idx(s)/ind(s)--> indice(s) COSMIT Merge if and else to elif COSMIT n --> n_samples COSMIT Use bincount only once COSMIT cls --> class_i / class_i (ith class indices) --> perm_indices_class_i FIX/ENH/TST Addressing the final reviews COSMIT c --> count FIX/TST make check_cv raise ValueError for string cv value TST nested cv (gs inside cross_val_score) works for diff cvs FIX/ENH Raise ValueError when labels is None for label based cvs; TST if labels is being passed correctly to the cv and that the ValueError is being propagated to the cross_val_score/predict and grid search FIX pass labels to cross_val_score FIX use make_classification DOC Add Returns; COSMIT Remove scaffolding TST add a test to check the _build_repr helper REVERT the old GS/RS should also be tested by the common tests. ENH Add a tuple of all/label based CVS FIX raise VE even at get_n_splits if labels is None FIX Fabian's comments PEP8
1 parent 409c888 commit 3f8743f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+6175
-103
lines changed

doc/whats_new.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,16 @@ New features
2525
Enhancements
2626
............
2727

28+
- The cross-validation iterators are now modified as cross-validation splitters
29+
which expose a ``split`` method that takes in the data and yields a generator
30+
for the different splits. This change makes it possible to do nested cross-validation
31+
with ease. (`#4294 https://github.com/scikit-learn/scikit-learn/pull/4294>`_) by `Raghav R V`_.
32+
33+
- The :mod:`cross_validation`, :mod:`grid_search` and :mod:`learning_curve`
34+
have been deprecated and the classes and functions have been reorganized into
35+
the :mod:`model_selection` module. (`#4294 https://github.com/scikit-learn/scikit-learn/pull/4294>`_) by `Raghav R V`_.
36+
37+
2838
Bug fixes
2939
.........
3040

sklearn/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@
6262
'ensemble', 'exceptions', 'externals', 'feature_extraction',
6363
'feature_selection', 'gaussian_process', 'grid_search',
6464
'isotonic', 'kernel_approximation', 'kernel_ridge',
65-
'lda', 'learning_curve',
66-
'linear_model', 'manifold', 'metrics', 'mixture', 'multiclass',
65+
'lda', 'learning_curve', 'linear_model', 'manifold', 'metrics',
66+
'mixture', 'model_selection', 'multiclass',
6767
'naive_bayes', 'neighbors', 'neural_network', 'pipeline',
6868
'preprocessing', 'qda', 'random_projection', 'semi_supervised',
6969
'svm', 'tree', 'discriminant_analysis',

sklearn/calibration.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from .utils.fixes import signature
2323
from .isotonic import IsotonicRegression
2424
from .svm import LinearSVC
25-
from .cross_validation import check_cv
25+
from .model_selection import check_cv
2626
from .metrics.classification import _check_binary_probabilistic_predictions
2727

2828

@@ -152,7 +152,7 @@ def fit(self, X, y, sample_weight=None):
152152
calibrated_classifier.fit(X, y)
153153
self.calibrated_classifiers_.append(calibrated_classifier)
154154
else:
155-
cv = check_cv(self.cv, X, y, classifier=True)
155+
cv = check_cv(self.cv, y, classifier=True)
156156
fit_parameters = signature(base_estimator.fit).parameters
157157
estimator_name = type(base_estimator).__name__
158158
if (sample_weight is not None
@@ -163,7 +163,7 @@ def fit(self, X, y, sample_weight=None):
163163
base_estimator_sample_weight = None
164164
else:
165165
base_estimator_sample_weight = sample_weight
166-
for train, test in cv:
166+
for train, test in cv.split(X, y):
167167
this_estimator = clone(base_estimator)
168168
if base_estimator_sample_weight is not None:
169169
this_estimator.fit(

sklearn/cluster/tests/test_bicluster.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
from scipy.sparse import csr_matrix, issparse
55

6-
from sklearn.grid_search import ParameterGrid
6+
from sklearn.model_selection import ParameterGrid
77

88
from sklearn.utils.testing import assert_equal
99
from sklearn.utils.testing import assert_almost_equal

sklearn/covariance/graph_lasso_.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ..utils.validation import check_random_state, check_array
2222
from ..linear_model import lars_path
2323
from ..linear_model import cd_fast
24-
from ..cross_validation import check_cv, cross_val_score
24+
from ..model_selection import check_cv, cross_val_score
2525
from ..externals.joblib import Parallel, delayed
2626
import collections
2727

@@ -580,7 +580,7 @@ def fit(self, X, y=None):
580580
emp_cov = empirical_covariance(
581581
X, assume_centered=self.assume_centered)
582582

583-
cv = check_cv(self.cv, X, y, classifier=False)
583+
cv = check_cv(self.cv, y, classifier=False)
584584

585585
# List of (alpha, scores, covs)
586586
path = list()
@@ -612,14 +612,13 @@ def fit(self, X, y=None):
612612
this_path = Parallel(
613613
n_jobs=self.n_jobs,
614614
verbose=self.verbose
615-
)(
616-
delayed(graph_lasso_path)(
617-
X[train], alphas=alphas,
618-
X_test=X[test], mode=self.mode,
619-
tol=self.tol, enet_tol=self.enet_tol,
620-
max_iter=int(.1 * self.max_iter),
621-
verbose=inner_verbose)
622-
for train, test in cv)
615+
)(delayed(graph_lasso_path)(X[train], alphas=alphas,
616+
X_test=X[test], mode=self.mode,
617+
tol=self.tol,
618+
enet_tol=self.enet_tol,
619+
max_iter=int(.1 * self.max_iter),
620+
verbose=inner_verbose)
621+
for train, test in cv.sp 10000 lit(X, y))
623622

624623
# Little danse to transform the list in what we need
625624
covs, _, scores = zip(*this_path)

sklearn/cross_validation.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@
3434
from .gaussian_process.kernels import Kernel as GPKernel
3535
from .exceptions import FitFailedWarning
3636

37+
38+
warnings.warn("This module has been deprecated in favor of the "
39+
"model_selection module into which all the refactored classes "
40+
"and functions are moved. Also note that the interface of the "
41+
"new CV iterators are different from that of this module. "
42+
"This module will be removed in 0.19.", DeprecationWarning)
43+
44+
3745
__all__ = ['KFold',
3846
'LabelKFold',
3947
'LeaveOneLabelOut',
@@ -304,7 +312,7 @@ class KFold(_BaseKFold):
304312
305313
See also
306314
--------
307-
StratifiedKFold: take label information into account to avoid building
315+
StratifiedKFold take label information into account to avoid building
308316
folds with imbalanced class distributions (for binary or multiclass
309317
classification tasks).
310318

sklearn/decomposition/tests/test_kernel_pca.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from sklearn.datasets import make_circles
1010
from sklearn.linear_model import Perceptron
1111
from sklearn.pipeline import Pipeline
12-
from sklearn.grid_search import GridSearchCV
12+
from sklearn.model_selection import GridSearchCV
1313
from sklearn.metrics.pairwise import rbf_kernel
1414

1515

sklearn/ensemble/tests/test_bagging.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@
2121
from sklearn.utils.testing import assert_warns_message
2222

2323
from sklearn.dummy import DummyClassifier, DummyRegressor
24-
from sklearn.grid_search import GridSearchCV, ParameterGrid
24+
from sklearn.model_selection import GridSearchCV, ParameterGrid
2525
from sklearn.ensemble import BaggingClassifier, BaggingRegressor
2626
from sklearn.linear_model import Perceptron, LogisticRegression
2727
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
2828
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
2929
from sklearn.svm import SVC, SVR
3030
from sklearn.pipeline import make_pipeline
3131
from sklearn.feature_selection import SelectKBest
32-
from sklearn.cross_validation import train_test_split
32+
from sklearn.model_selection import train_test_split
3333
from sklearn.datasets import load_boston, load_iris, make_hastie_10_2
3434
from sklearn.utils import check_random_state
3535

sklearn/ensemble/tests/test_forest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from sklearn.ensemble import RandomForestClassifier
3939
from sklearn.ensemble import RandomForestRegressor
4040
from sklearn.ensemble import RandomTreesEmbedding
41-
from sklearn.grid_search import GridSearchCV
41+
from sklearn.model_selection import GridSearchCV
4242
from sklearn.svm import LinearSVC
4343
from sklearn.utils.fixes import bincount
4444
from sklearn.utils.validation import check_random_state

sklearn/ensemble/tests/test_voting_classifier.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
from sklearn.naive_bayes import GaussianNB
88
from sklearn.ensemble import RandomForestClassifier
99
from sklearn.ensemble import VotingClassifier
10-
from sklearn.grid_search import GridSearchCV
10+
from sklearn.model_selection import GridSearchCV
1111
from sklearn import datasets
12-
from sklearn import cross_validation
12+
from sklearn.model_selection import cross_val_score
1313
from sklearn.datasets import make_multilabel_classification
1414
from sklearn.svm import SVC
1515
from sklearn.multiclass import OneVsRestClassifier
@@ -27,11 +27,7 @@ def test_majority_label_iris():
2727
eclf = VotingClassifier(estimators=[
2828
('lr', clf1), ('rf', clf2), ('gnb', clf3)],
2929
voting='hard')
30-
scores = cross_validation.cross_val_score(eclf,
31-
X,
32-
y,
33-
cv=5,
34-
scoring='accuracy')
30+
scores = cross_val_score(eclf, X, y, cv=5, scoring='accuracy')
3531
assert_almost_equal(scores.mean(), 0.95, decimal=2)
3632

3733

@@ -55,11 +51,7 @@ def test_weights_iris():
5551
('lr', clf1), ('rf', clf2), ('gnb', clf3)],
5652
voting='soft',
5753
weights=[1, 2, 10])
58-
scores = cross_validation.cross_val_score(eclf,
59-
X,
60-
y,
61-
cv=5,
62-
scoring='accuracy')
54+
scores = cross_val_score(eclf, X, y, cv=5, scoring='accuracy')
6355
assert_almost_equal(scores.mean(), 0.93, decimal=2)
6456

6557

sklearn/ensemble/tests/test_weight_boosting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from sklearn.utils.testing import assert_raises, assert_raises_regexp
88

99
from sklearn.base import BaseEstimator
10-
from sklearn.cross_validation import train_test_split
11-
from sklearn.grid_search import GridSearchCV
10+
from sklearn.model_selection import train_test_split
11+
from sklearn.model_selection import GridSearchCV
1212
from sklearn.ensemble import AdaBoostClassifier
1313
from sklearn.ensemble import AdaBoostRegressor
1414
from sklearn.ensemble import weight_boosting

sklearn/exceptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class FitFailedWarning(RuntimeWarning):
8585
8686
Examples
8787
--------
88-
>>> from sklearn.grid_search import GridSearchCV
88+
>>> from sklearn.model_selection import GridSearchCV
8989
>>> from sklearn.svm import LinearSVC
9090
>>> from sklearn.exceptions import FitFailedWarning
9191
>>> import warnings

sklearn/feature_extraction/tests/test_text.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212

1313
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS
1414

15-
from sklearn.cross_validation import train_test_split
16-
from sklearn.cross_validation import cross_val_score
17-
from sklearn.grid_search import GridSearchCV
15+
from sklearn.model_selection import train_test_split
16+
from sklearn.model_selection import cross_val_score
17+
from sklearn.model_selection import GridSearchCV
1818
from sklearn.pipeline import Pipeline
1919
from sklearn.svm import LinearSVC
2020

sklearn/feature_selection/rfe.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from ..base import MetaEstimatorMixin
1515< F438 div class="diff-text-inner">from ..base import clone
1616
from ..base import is_classifier
17-
from ..cross_validation import check_cv
18-
from ..cross_validation import _safe_split, _score
17+
from ..model_selection import check_cv
18+
from ..model_selection._validation import _safe_split, _score
1919
from ..metrics.scorer import check_scoring
2020
from .base import SelectorMixin
2121

@@ -373,7 +373,7 @@ def fit(self, X, y):
373373
X, y = check_X_y(X, y, "csr")
374374

375375
# Initialization
376-
cv = check_cv(self.cv, X, y, is_classifier(self.estimator))
376+
cv = check_cv(self.cv, y, is_classifier(self.estimator))
377377
scorer = check_scoring(self.estimator, scoring=self.scoring)
378378
n_features = X.shape[1]
379379
n_features_to_select = 1
@@ -382,7 +382,7 @@ def fit(self, X, y):
382382
scores = []
383383

384384
# Cross-validation
385-
for n, (train, test) in enumerate(cv):
385+
for n, (train, test) in enumerate(cv.split(X, y)):
386386
X_train, y_train = _safe_split(self.estimator, X, y, train)
387387
X_test, y_test = _safe_split(self.estimator, X, y, test, train)
388388

@@ -414,7 +414,7 @@ def fit(self, X, y):
414414
self.estimator_ = clone(self.estimator)
415415
self.estimator_.fit(self.transform(X), y)
416416

417-
# Fixing a normalization error, n is equal to len(cv) - 1
418-
# here, the scores are normalized by len(cv)
419-
self.grid_scores_ = scores / len(cv)
417+
# Fixing a normalization error, n is equal to get_n_splits(X, y) - 1
418+
# here, the scores are normalized by get_n_splits(X, y)
419+
self.grid_scores_ = scores / cv.get_n_splits(X, y)
420420
return self

sklearn/feature_selection/tests/test_rfe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sklearn.metrics import zero_one_loss
1313
from sklearn.svm import SVC, SVR
1414
from sklearn.ensemble import RandomForestClassifier
15-
from sklearn.cross_validation import cross_val_score
15+
from sklearn.model_selection import cross_val_score
1616

1717
from sklearn.utils import check_random_state
1818
from sklearn.utils.testing import ignore_warnings

sklearn/grid_search.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@
3737
'ParameterSampler', 'RandomizedSearchCV']
3838

3939

40+
warnings.warn("This module has been deprecated in favor of the "
41+
"model_selection module into which all the refactored classes "
42+
"and functions are moved. This module will be removed in 0.19.",
43+
DeprecationWarning)
44+
45+
4046
class ParameterGrid(object):
4147
"""Grid of parameters with a discrete number of values for each.
4248

sklearn/learning_curve.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@
1717
from .utils.fixes import astype
1818

1919

20+
warnings.warn("This module has been deprecated in favor of the "
21+
"model_selection module into which all the functions are moved."
22+
" This module will be removed in 0.19",
23+
DeprecationWarning)
24+
25+
2026
__all__ = ['learning_curve', 'validation_curve']
2127

2228

sklearn/linear_model/coordinate_descent.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from .base import center_data, sparse_center_data
1818
from ..utils import check_array, check_X_y, deprecated
1919
from ..utils.validation import check_random_state
20-
from ..cross_validation import check_cv
20+
from ..model_selection import check_cv
2121
from ..externals.joblib import Parallel, delayed
2222
from ..externals import six
2323
from ..externals.six.moves import xrange
@@ -1120,10 +1120,10 @@ def fit(self, X, y):
11201120
path_params['copy_X'] = False
11211121

11221122
# init cross-validation generator
1123-
cv = check_cv(self.cv, X)
1123+
cv = check_cv(self.cv)
11241124

11251125
# Compute path for all folds and compute MSE to get the best alpha
1126-
folds = list(cv)
1126+
folds = list(cv.split(X))
11271127
best_mse = np.inf
11281128

11291129
# We do a double for loop folded in one, in order to be able to

sklearn/linear_model/least_angle.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from .base import LinearModel
2323
from ..base import RegressorMixin
2424
from ..utils import arrayfuncs, as_float_array, check_X_y
25-
from ..cross_validation import check_cv
25+
from ..model_selection import check_cv
2626
from ..exceptions import ConvergenceWarning
2727
from ..externals.joblib import Parallel, delayed
2828
from ..externals.six.moves import xrange
@@ -1079,7 +1079,7 @@ def fit(self, X, y):
10791079
y = as_float_array(y, copy=self.copy_X)
10801080

10811081
# init cross-validation generator
1082-
cv = check_cv(self.cv, X, y, classifier=False)
1082+
cv = check_cv(self.cv, classifier=False)
10831083

10841084
Gram = 'auto' if self.precompute else None
10851085

@@ -1089,7 +1089,7 @@ def fit(self, X, y):
10891089
method=self.method, verbose=max(0, self.verbose - 1),
10901090
normalize=self.normalize, fit_intercept=self.fit_intercept,
10911091
max_iter=self.max_iter, eps=self.eps, positive=self.positive)
1092-
for train, test in cv)
1092+
for train, test in cv.split(X, y))
10931093
all_alphas = np.concatenate(list(zip(*cv_paths))[0])
10941094
# Unique also sorts
10951095
all_alphas = np.unique(all_alphas)

sklearn/linear_model/logistic.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
"""
23
Logistic Regression
34
"""
@@ -32,7 +33,7 @@
3233
from ..utils.fixes import expit
3334
from ..utils.multiclass import check_classification_targets
3435
from ..externals.joblib import Parallel, delayed
35-
from ..cross_validation import check_cv
36+
from ..model_selection import check_cv
3637
from ..externals import six
3738
from ..metrics import SCORERS
3839

@@ -1309,7 +1310,7 @@ class LogisticRegressionCV(LogisticRegression, BaseEstimator,
13091310
cv : integer or cross-validation generator
13101311
The default cross-validation generator used is Stratified K-Folds.
13111312
If an integer is provided, then it is the number of folds used.
1312-
See the module :mod:`sklearn.cross_validation` module for the
1313+
See the module :mod:`sklearn.model_selection` module for the
13131314
list of possible cross-validation objects.
13141315
13151316
penalty : str, 'l1' or 'l2'
@@ -1506,8 +1507,8 @@ def fit(self, X, y, sample_weight=None):
15061507
check_consistent_length(X, y)
15071508

15081509
# init cross-validation generator
1509-
cv = check_cv(self.cv, X, y, classifier=True)
1510-
folds = list(cv)
1510+
cv = check_cv(self.cv, y, classifier=True)
1511+
folds = list(cv.split(X, y))
15111512

15121513
self._enc = LabelEncoder()
15131514
self._enc.fit(y)

0 commit comments

Comments
 (0)
0