From 482f7ff62b571dd207aa3385af8cf49fc55484d9 Mon Sep 17 00:00:00 2001 From: Scott Sievert Date: Wed, 13 Jun 2018 11:58:00 -0500 Subject: [PATCH 1/7] MAINT: add fit kwarg to cross validation functions --- doc/whats_new/v0.20.rst | 5 ++ sklearn/model_selection/_validation.py | 55 +++++++++++++++---- .../model_selection/tests/test_validation.py | 34 +++++++++++- 3 files changed, 82 insertions(+), 12 deletions(-) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index da5141efb245d..92bbbe05e024d 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -109,6 +109,11 @@ Model evaluation ``'balanced_accuracy'`` scorer for binary classification. :issue:`8066` by :user:`xyguo` and :user:`Aman Dalmia `. +- A ``fit`` keyword argument has been added to + :func:`model_selection.cross_val_score` and + :func:`model_selection.cross_validate` to control calling ``estimator.fit`` + before scoring. See :issue:`` by :user:`Scott Sievert `. + Decomposition, manifold learning and clustering - :class:`cluster.AgglomerativeClustering` now supports Single Linkage diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index 50af9b5dd5504..b5731ef71294b 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -40,7 +40,7 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None, n_jobs=1, verbose=0, fit_params=None, pre_dispatch='2*n_jobs', return_train_score="warn", - return_estimator=False): + return_estimator=False, fit=None): """Evaluate metric(s) by cross-validation and also record fit/score times. Read more in the :ref:`User Guide `. @@ -133,6 +133,13 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None, return_estimator : boolean, default False Whether to return the estimators fitted on each split. + fit : callable, str, optional + If None (default), call ``estimator.fit`` before scoring the model. + If callable, call ``fit(est, X, y, **fit_params)``. If + ``fit=='partial_fit'``, call ``estimator.partial_fit`` before scoring. + If ``fit`` is not None, the estimator is assumed to be pickleable. + + Returns ------- scores : dict of float arrays of shape=(n_splits,) @@ -209,13 +216,16 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None, # We clone the estimator to make sure that all the folds are # independent, and that it is pickle-able. + # We do not clonse the estimator if a custom fit function is supplied, + # implying that the estimator has been trained. parallel = Parallel(n_jobs=n_jobs, verbose=verbose, pre_dispatch=pre_dispatch) scores = parallel( delayed(_fit_and_score)( - clone(estimator), X, y, scorers, train, test, verbose, None, - fit_params, return_train_score=return_train_score, - return_times=True, return_estimator=return_estimator) + clone(estimator) if fit is None else estimator, X, y, scorers, + train, test, verbose, None, fit_params, + return_train_score=return_train_score, return_times=True, + return_estimator=return_estimator, fit=fit) for train, test in cv.split(X, y, groups)) zipped_scores = list(zip(*scores)) @@ -254,7 +264,7 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None, def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None, n_jobs=1, verbose=0, fit_params=None, - pre_dispatch='2*n_jobs'): + pre_dispatch='2*n_jobs', fit=None): """Evaluate a score by cross-validation Read more in the :ref:`User Guide `. 
@@ -323,6 +333,12 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None, - A string, giving an expression as a function of n_jobs, as in '2*n_jobs' + fit : callable, str, optional + If None (default), call ``estimator.fit`` before scoring the model. + If callable, call ``fit`` before scoring (which assumes the model can + be pickled). If ``fit=='partial_fit'``, call ``estimator.partial_fit`` + before scoring. + Returns ------- scores : array of float, shape=(len(list(cv)),) @@ -361,7 +377,7 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None, return_train_score=False, n_jobs=n_jobs, verbose=verbose, fit_params=fit_params, - pre_dispatch=pre_dispatch) + pre_dispatch=pre_dispatch, fit=fit) return cv_results['test_score'] @@ -369,7 +385,7 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters, fit_params, return_train_score=False, return_parameters=False, return_n_test_samples=False, return_times=False, return_estimator=False, - error_score='raise-deprecating'): + error_score='raise-deprecating', fit=None): """Fit estimator and compute scores for a given dataset split. Parameters @@ -431,6 +447,12 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, return_estimator : boolean, optional, default: False Whether to return the fitted estimator. + fit : callable, str, optional + If None (default), call ``estimator.fit`` before scoring the model. + If callable, call ``fit(est, X, y, **fit_params)``. If + ``fit=='partial_fit'``, call ``estimator.partial_fit`` before scoring. + If ``fit`` is not None, the estimator is assumed to be pickleable. + Returns ------- train_scores : dict of scorer name -> float, optional @@ -469,7 +491,7 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, for k, v in fit_params.items()]) train_scores = {} - if parameters is not None: + if parameters is not None and fit is None: estimator.set_params(**parameters) start_time = time.time() @@ -481,10 +503,21 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, n_scorers = len(scorer.keys()) if is_multimetric else 1 try: - if y_train is None: - estimator.fit(X_train, **fit_params) + fn = 'fit' if fit is None else fit + if isinstance(fn, str) and fn in {'fit', 'partial_fit'}: + if y_train is None: + getattr(estimator, fn)(X_train, **fit_params) + else: + getattr(estimator, fn)(X_train, y_train, **fit_params) + elif callable(fn): + if y_train is None: + fit(estimator, X_train, **fit_params) + else: + fit(estimator, X_train, y_train, **fit_params) else: - estimator.fit(X_train, y_train, **fit_params) + msg = ('keyword argument fit="{fit}" not recognized. 
fit should ' + 'be "fit", "partial_fit" or callable.') + raise ValueError(msg.format(fit=fit)) except Exception as e: # Note fit time as time until error diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index a537b9f53518a..c9bbbfa7b5887 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -10,7 +10,7 @@ import pytest import numpy as np from scipy.sparse import coo_matrix, csr_matrix -from sklearn.exceptions import FitFailedWarning +from sklearn.exceptions import FitFailedWarning, NotFittedError from sklearn.tests.test_grid_search import FailingClassifier @@ -618,6 +618,38 @@ def assert_fit_params(clf): cross_val_score(clf, X, y, fit_params=fit_params) +@pytest.mark.parametrize("fit", [ + lambda est, X, y, **fit_params: est.partial_fit(X, y, **fit_params), + 'partial_fit', +]) +def test_cross_validate_fit_kwarg(fit): + X, y = make_classification(n_samples=20, n_classes=2, random_state=0) + classes = np.unique(y) + + with warnings.catch_warnings(record=True): + clf = SGDClassifier(random_state=0) + clf2 = SGDClassifier(random_state=0) + with pytest.raises(NotFittedError): + cross_validate(clf, X, y, fit=lambda est, *args, **kwargs: est) + + # repeat to make sure estimator is still pickleable + for _ in range(5): + scores_no_fit = cross_validate(clf, X, y, fit=fit, + fit_params={'classes': classes}) + scores_fit = cross_validate(clf2, X, y) + + assert_true(set(scores_no_fit.keys()) == set(scores_fit.keys())) + + +def test_cross_validate_fit_kwarg_raises(): + clf = SGDClassifier(random_state=0) + X, y = make_classification(n_samples=20, n_classes=2, random_state=0) + classes = np.unique(y) + with warnings.catch_warnings(record=True): + with pytest.raises(ValueError, match='fit should be'): + cross_validate(clf, X, y, fit='foo') + + def test_cross_val_score_score_func(): clf = MockClassifier() _score_func_args = [] From 67376a4f6150ee1e655e933909ee44c7c3e28768 Mon Sep 17 00:00:00 2001 From: Scott Sievert Date: Wed, 13 Jun 2018 15:55:15 -0500 Subject: [PATCH 2/7] BUG: make sure iterations increase as expected --- sklearn/model_selection/_validation.py | 70 +++++++++---------- .../model_selection/tests/test_validation.py | 57 ++++++++------- 2 files changed, 68 insertions(+), 59 deletions(-) diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index b5731ef71294b..5b0e18a2e7c0f 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -31,6 +31,7 @@ from ..exceptions import FitFailedWarning from ._split import check_cv from ..preprocessing import LabelEncoder +from copy import deepcopy __all__ = ['cross_validate', 'cross_val_score', 'cross_val_predict', @@ -40,7 +41,7 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None, n_jobs=1, verbose=0, fit_params=None, pre_dispatch='2*n_jobs', return_train_score="warn", - return_estimator=False, fit=None): + return_estimator=False, partial_fit=False): """Evaluate metric(s) by cross-validation and also record fit/score times. Read more in the :ref:`User Guide `. @@ -133,11 +134,11 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None, return_estimator : boolean, default False Whether to return the estimators fitted on each split. - fit : callable, str, optional - If None (default), call ``estimator.fit`` before scoring the model. - If callable, call ``fit(est, X, y, **fit_params)``. 
If
-        ``fit=='partial_fit'``, call ``estimator.partial_fit`` before scoring.
-        If ``fit`` is not None, the estimator is assumed to be pickleable.
+    partial_fit : boolean or integer, default False
+        If False (default), call ``fit``. If True, call
+        ``estimator.partial_fit`` once. If an integer, call ``partial_fit``
+        that many times. ``estimator`` is assumed to be pickleable if
+        ``partial_fit`` is not False.
 
 
     Returns
@@ -216,16 +217,16 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None,
 
     # We clone the estimator to make sure that all the folds are
     # independent, and that it is pickle-able.
-    # We do not clonse the estimator if a custom fit function is supplied,
+    # We do not clone the estimator if a custom fit function is supplied,
     # implying that the estimator has been trained.
     parallel = Parallel(n_jobs=n_jobs, verbose=verbose,
                         pre_dispatch=pre_dispatch)
     scores = parallel(
         delayed(_fit_and_score)(
-            clone(estimator) if fit is None else estimator, X, y, scorers,
+            clone(estimator) if not partial_fit else deepcopy(estimator), X, y, scorers,
             train, test, verbose, None, fit_params,
             return_train_score=return_train_score, return_times=True,
-            return_estimator=return_estimator, fit=fit)
+            return_estimator=return_estimator, partial_fit=partial_fit)
         for train, test in cv.split(X, y, groups))
 
     zipped_scores = list(zip(*scores))
@@ -264,7 +265,7 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None,
 
 def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
                     n_jobs=1, verbose=0, fit_params=None,
-                    pre_dispatch='2*n_jobs', fit=None):
+                    pre_dispatch='2*n_jobs', partial_fit=False):
     """Evaluate a score by cross-validation
 
     Read more in the :ref:`User Guide `.
@@ -333,11 +334,11 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
         - A string, giving an expression as a function of n_jobs,
           as in '2*n_jobs'
 
-    fit : callable, str, optional
-        If None (default), call ``estimator.fit`` before scoring the model.
-        If callable, call ``fit`` before scoring (which assumes the model can
-        be pickled). If ``fit=='partial_fit'``, call ``estimator.partial_fit``
-        before scoring.
+    partial_fit : boolean or integer, default False
+        If False (default), call ``fit``. If True, call
+        ``estimator.partial_fit`` once. If an integer, call ``partial_fit``
+        that many times. ``estimator`` is assumed to be pickleable if
+        ``partial_fit`` is not False.
 
     Returns
     -------
     scores : array of float, shape=(len(list(cv)),)
@@ -377,7 +378,8 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
                                 return_train_score=False,
                                 n_jobs=n_jobs, verbose=verbose,
                                 fit_params=fit_params,
-                                pre_dispatch=pre_dispatch, fit=fit)
+                                pre_dispatch=pre_dispatch,
+                                partial_fit=partial_fit)
     return cv_results['test_score']
 
 
@@ -385,7 +387,7 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
                    parameters, fit_params, return_train_score=False,
                    return_parameters=False, return_n_test_samples=False,
                    return_times=False, return_estimator=False,
-                   error_score='raise-deprecating', fit=None):
+                   error_score='raise-deprecating', partial_fit=False):
     """Fit estimator and compute scores for a given dataset split.
 
     Parameters
@@ -447,11 +449,11 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
     return_estimator : boolean, optional, default: False
         Whether to return the fitted estimator.
 
-    fit : callable, str, optional
-        If None (default), call ``estimator.fit`` before scoring the model.
-        If callable, call ``fit(est, X, y, **fit_params)``. If
-        ``fit=='partial_fit'``, call ``estimator.partial_fit`` before scoring.
-        If ``fit`` is not None, the estimator is assumed to be pickleable.
+    partial_fit : boolean or integer, default False
+        If False (default), call ``fit``. If True, call
+        ``estimator.partial_fit`` once. If an integer, call ``partial_fit``
+        that many times. ``estimator`` is assumed to be pickleable if
+        ``partial_fit`` is not False.
 
     Returns
     -------
@@ -491,7 +493,7 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
                             for k, v in fit_params.items()])
 
     train_scores = {}
-    if parameters is not None and fit is None:
+    if parameters is not None and not partial_fit:
         estimator.set_params(**parameters)
 
     start_time = time.time()
@@ -503,21 +505,19 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
     n_scorers = len(scorer.keys()) if is_multimetric else 1
 
     try:
-        fn = 'fit' if fit is None else fit
-        if isinstance(fn, str) and fn in {'fit', 'partial_fit'}:
+        if not isinstance(partial_fit, (bool, int)):
+            raise ValueError('partial_fit must be a boolean or an integer')
+        if isinstance(partial_fit, bool) and not partial_fit:
            if y_train is None:
-                getattr(estimator, fn)(X_train, **fit_params)
+                estimator.fit(X_train, **fit_params)
            else:
-                getattr(estimator, fn)(X_train, y_train, **fit_params)
-        elif callable(fn):
-            if y_train is None:
-                fit(estimator, X_train, **fit_params)
-            else:
-                fit(estimator, X_train, y_train, **fit_params)
+                estimator.fit(X_train, y_train, **fit_params)
         else:
-            msg = ('keyword argument fit="{fit}" not recognized. fit should '
-                   'be "fit", "partial_fit" or callable.')
-            raise ValueError(msg.format(fit=fit))
+            for _ in range(int(partial_fit)):
+                if y_train is None:
+                    estimator.partial_fit(X_train, **fit_params)
+                else:
+                    estimator.partial_fit(X_train, y_train, **fit_params)
 
     except Exception as e:
         # Note fit time as time until error
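
(Usage sketch, illustrative and not part of the patch: how the ``partial_fit``
keyword added above is meant to be called. The estimator and the ``classes``
fit-param mirror the tests below; any estimator that implements
``partial_fit`` would do.)

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.linear_model import SGDClassifier
    from sklearn.model_selection import cross_validate

    X, y = make_classification(n_samples=100, n_classes=2, random_state=0)
    clf = SGDClassifier(random_state=0)

    # partial_fit=False (the default): each split clones ``clf`` and
    # calls fit(), exactly as before this patch.
    scores = cross_validate(clf, X, y)

    # partial_fit=4: each split deep-copies ``clf`` and calls
    # partial_fit() four times on its training fold. ``classes`` is
    # forwarded through fit_params because SGDClassifier.partial_fit
    # requires it on the first call.
    scores = cross_validate(clf, X, y, partial_fit=4,
                            fit_params={'classes': np.unique(y)})
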
diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py
index c9bbbfa7b5887..b5cf64056321e 100644
--- a/sklearn/model_selection/tests/test_validation.py
+++ b/sklearn/model_selection/tests/test_validation.py
@@ -618,36 +618,45 @@ def assert_fit_params(clf):
     cross_val_score(clf, X, y, fit_params=fit_params)
 
 
-@pytest.mark.parametrize("fit", [
-    lambda est, X, y, **fit_params: est.partial_fit(X, y, **fit_params),
-    'partial_fit',
-])
-def test_cross_validate_fit_kwarg(fit):
+@pytest.mark.parametrize("partial_fit", [True, 4, 8])
+def test_cross_validate_fit_kwarg(partial_fit):
     X, y = make_classification(n_samples=20, n_classes=2, random_state=0)
     classes = np.unique(y)
 
-    with warnings.catch_warnings(record=True):
-        clf = SGDClassifier(random_state=0)
-        clf2 = SGDClassifier(random_state=0)
-        with pytest.raises(NotFittedError):
-            cross_validate(clf, X, y, fit=lambda est, *args, **kwargs: est)
-
-        # repeat to make sure estimator is still pickleable
-        for _ in range(5):
-            scores_no_fit = cross_validate(clf, X, y, fit=fit,
-                                           fit_params={'classes': classes})
-        scores_fit = cross_validate(clf2, X, y)
-
-    assert_true(set(scores_no_fit.keys()) == set(scores_fit.keys()))
-
-
-def test_cross_validate_fit_kwarg_raises():
+    tol = -np.inf
+    clf_p_fit = SGDClassifier(random_state=0, tol=tol, max_iter=10)
+    clf_fit = SGDClassifier(random_state=0, tol=tol, max_iter=100)
+
+    cv = 2
+    # scores_partial_fit
+    scores_p_fit = cross_validate(clf_p_fit, X, y, partial_fit=partial_fit,
+                                  fit_params={'classes': classes},
+                                  return_estimator=True, cv=cv)
+    # score_fit
+    scores_fit = cross_validate(clf_fit, X, y, return_estimator=True, cv=cv)
+
+    
assert_true(set(scores_p_fit.keys()) == set(scores_fit.keys())) + + clfs_p_fit = scores_p_fit.pop('estimator') + clfs_fit = scores_fit.pop('estimator') + for clf_fit, clf_p_fit in zip(clfs_fit, clfs_p_fit): + assert_true(clf_p_fit.t_ * 10 < clf_fit.t_) + assert_true(clf_p_fit.t_ - 1 == + int(partial_fit * X.shape[0] * (cv - 1) / cv)) + + +@pytest.mark.parametrize("partial_fit", ['foo', 1.0, 1, True]) +def test_cross_validate_fit_kwarg_raises(partial_fit): clf = SGDClassifier(random_state=0) X, y = make_classification(n_samples=20, n_classes=2, random_state=0) classes = np.unique(y) - with warnings.catch_warnings(record=True): - with pytest.raises(ValueError, match='fit should be'): - cross_validate(clf, X, y, fit='foo') + + if isinstance(partial_fit, (bool, int)): + cross_validate(clf, X, y, partial_fit=partial_fit, + fit_params={'classes': classes}) + else: + with pytest.raises(ValueError, match='partial_fit must be'): + cross_validate(clf, X, y, partial_fit=partial_fit, + fit_params={'classes': classes}) def test_cross_val_score_score_func(): From 4ab23e61d52d7342bcafcf7c7ba660005e94d2f8 Mon Sep 17 00:00:00 2001 From: Scott Sievert Date: Thu, 14 Jun 2018 17:21:19 -0500 Subject: [PATCH 3/7] ENH: add val_data param to cross_validate --- sklearn/model_selection/_validation.py | 63 ++++++++++--------- .../model_selection/tests/test_validation.py | 17 +++++ 2 files changed, 50 insertions(+), 30 deletions(-) diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index 5b0e18a2e7c0f..b8772c150fc12 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -32,6 +32,7 @@ from ._split import check_cv from ..preprocessing import LabelEncoder from copy import deepcopy +from ..base import BaseEstimator __all__ = ['cross_validate', 'cross_val_score', 'cross_val_predict', @@ -41,7 +42,7 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None, n_jobs=1, verbose=0, fit_params=None, pre_dispatch='2*n_jobs', return_train_score="warn", - return_estimator=False, partial_fit=False): + return_estimator=False, partial_fit=False, test_data=None): """Evaluate metric(s) by cross-validation and also record fit/score times. Read more in the :ref:`User Guide `. @@ -215,19 +216,16 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None, cv = check_cv(cv, y, classifier=is_classifier(estimator)) scorers, _ = _check_multimetric_scoring(estimator, scoring=scoring) - # We clone the estimator to make sure that all the folds are - # independent, and that it is pickle-able. - # We do not clone the estimator if a custom fit function is supplied, - # implying that the estimator has been trained. 
+    ests = _get_estimators(estimator, cv.get_n_splits(X, y, groups))
     parallel = Parallel(n_jobs=n_jobs, verbose=verbose,
                         pre_dispatch=pre_dispatch)
     scores = parallel(
         delayed(_fit_and_score)(
-            clone(estimator) if not partial_fit else deepcopy(estimator), X, y, scorers,
-            train, test, verbose, None, fit_params,
+            est, X, y, scorers, train, test, verbose, None, fit_params,
             return_train_score=return_train_score, return_times=True,
-            return_estimator=return_estimator, partial_fit=partial_fit)
-        for train, test in cv.split(X, y, groups))
+            return_estimator=return_estimator, partial_fit=partial_fit,
+            test_data=test_data)
+        for est, (train, test) in zip(ests, cv.split(X, y, groups)))
 
     zipped_scores = list(zip(*scores))
     if return_train_score:
@@ -263,9 +261,20 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None,
     return ret
 
 
+def _get_estimators(estimator, n):
+    if isinstance(estimator, (list, tuple)):
+        ests = (e for e in estimator)
+        if len(ests) != n:
+            msg = ('the number of estimators ({n_ests}) being fit is not equal to '
+                   'the number of splits for cross validation ({n_splits}).')
+            raise ValueError(msg.format(n_ests=len(ests), n_splits=cv.n_splits))
+        return ests
+    return (clone(estimator) for _ in range(n))
+
+
 def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
                     n_jobs=1, verbose=0, fit_params=None,
-                    pre_dispatch='2*n_jobs', partial_fit=False):
+                    pre_dispatch='2*n_jobs'):
     """Evaluate a score by cross-validation
 
     Read more in the :ref:`User Guide `.
@@ -334,12 +343,6 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
         - A string, giving an expression as a function of n_jobs,
           as in '2*n_jobs'
 
-    partial_fit : boolean or integer, default False
-        If False (default), call ``fit``. If True, call
-        ``estimator.partial_fit`` once. If an integer, call ``partial_fit``
-        that many times. ``estimator`` is assumed to be pickleable if
-        ``partial_fit`` is not False.
 
     Returns
     -------
     scores : array of float, shape=(len(list(cv)),)
@@ -378,8 +381,7 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
                                 return_train_score=False,
                                 n_jobs=n_jobs, verbose=verbose,
                                 fit_params=fit_params,
-                                pre_dispatch=pre_dispatch,
-                                partial_fit=partial_fit)
+                                pre_dispatch=pre_dispatch)
     return cv_results['test_score']
 
 
@@ -387,7 +389,8 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
                    parameters, fit_params, return_train_score=False,
                    return_parameters=False, return_n_test_samples=False,
                    return_times=False, return_estimator=False,
-                   error_score='raise-deprecating', partial_fit=False):
+                   error_score='raise-deprecating', partial_fit=False,
+                   test_data=None):
     """Fit estimator and compute scores for a given dataset split.
 
     
Parameters
@@ -498,8 +501,14 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
 
     start_time = time.time()
 
-    X_train, y_train = _safe_split(estimator, X, y, train)
-    X_test, y_test = _safe_split(estimator, X, y, test, train)
+    if test_data is None:
+        X_train, y_train = _safe_split(estimator, X, y, train)
+        X_test, y_test = _safe_split(estimator, X, y, test, train)
+    else:
+        X_train, y_train = X, y
+        X_test = test_data[0]
+        y_test = test_data[1] if y_train is not None else None
+    train_data = (X_train, y_train) if y_train is not None else (X_train, )
 
     is_multimetric = not callable(scorer)
     n_scorers = len(scorer.keys()) if is_multimetric else 1
 
     try:
         if not isinstance(partial_fit, (bool, int)):
             raise ValueError('partial_fit must be a boolean or an integer')
-        if isinstance(partial_fit, bool) and not partial_fit:
-            if y_train is None:
-                estimator.fit(X_train, **fit_params)
-            else:
-                estimator.fit(X_train, y_train, **fit_params)
+        if not partial_fit:
+            estimator.fit(*train_data, **fit_params)
         else:
             for _ in range(int(partial_fit)):
-                if y_train is None:
-                    estimator.partial_fit(X_train, **fit_params)
-                else:
-                    estimator.partial_fit(X_train, y_train, **fit_params)
+                estimator.partial_fit(*train_data, **fit_params)
 
     except Exception as e:
         # Note fit time as time until error
diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py
index b5cf64056321e..d1408a3fd71b9 100644
--- a/sklearn/model_selection/tests/test_validation.py
+++ b/sklearn/model_selection/tests/test_validation.py
@@ -659,6 +659,23 @@ def test_cross_validate_fit_kwarg_raises(partial_fit):
                        fit_params={'classes': classes})
 
 
+def test_cross_validate_val_set():
+    n, d = 100, 2
+    X_train, y_train = make_classification(n_samples=n, n_classes=2,
+                                           n_features=d, random_state=0,
+                                           n_redundant=0, n_informative=d)
+    rng = np.random.RandomState(0)
+    X_test = rng.randn(n, d)
+    y_test = (np.sign(rng.randn(n)) + 1) / 2
+
+    clf = SGDClassifier(random_state=0)
+    ret = cross_validate(clf, X_train, y_train, val_data=(X_test, y_test))
+
+    assert_true(ret['test_score'].max() < ret['train_score'].min())
+    assert_true(ret['test_score'].max() < 0.48)
+    assert_true(0.85 < ret['train_score'].min())
+
+
 def test_cross_val_score_score_func():
     clf = MockClassifier()
     _score_func_args = []
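
(Usage sketch, illustrative and not part of the patch: the ``test_data``
keyword that PATCH 3/7 adds to ``cross_validate``. When ``test_data`` is
given, each split trains on all of ``X`` and ``y`` and is scored on the
supplied set instead of the held-out fold.)

    from sklearn.datasets import make_classification
    from sklearn.linear_model import SGDClassifier
    from sklearn.model_selection import cross_validate

    X, y = make_classification(n_samples=100, random_state=0)
    X_val, y_val = make_classification(n_samples=50, random_state=1)

    clf = SGDClassifier(random_state=0)
    ret = cross_validate(clf, X, y, test_data=(X_val, y_val))
    # ret['test_score'] is now computed against X_val/y_val on every split.
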
From b96d3df6b097c3feb5c8e8c3c990c9fc054d1059 Mon Sep 17 00:00:00 2001
From: Scott Sievert
Date: Thu, 14 Jun 2018 18:47:28 -0500
Subject: [PATCH 4/7] TST: Make sure repeated calls work

---
 sklearn/model_selection/_validation.py       | 20 +++++----
 .../model_selection/tests/test_validation.py | 57 ++++++++++++-------
 2 files changed, 49 insertions(+), 28 deletions(-)

diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py
index b8772c150fc12..c1bd92306159a 100644
--- a/sklearn/model_selection/_validation.py
+++ b/sklearn/model_selection/_validation.py
@@ -214,9 +214,9 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None,
     X, y, groups = indexable(X, y, groups)
 
     cv = check_cv(cv, y, classifier=is_classifier(estimator))
-    scorers, _ = _check_multimetric_scoring(estimator, scoring=scoring)
-    ests = _get_estimators(estimator, cv.get_n_splits(X, y, groups))
+    ests = _get_estimators(estimator, cv.get_n_splits(X, y, groups))
+    scorers, _ = _check_multimetric_scoring(ests[0], scoring=scoring)
 
     parallel = Parallel(n_jobs=n_jobs, verbose=verbose,
                         pre_dispatch=pre_dispatch)
     scores = parallel(
         delayed(_fit_and_score)(
@@ -262,14 +262,18 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None,
 
 def _get_estimators(estimator, n):
-    if isinstance(estimator, (list, tuple)):
-        ests = (e for e in estimator)
-        if len(ests) != n:
+    if isinstance(estimator, tuple):
+        estimator = list(estimator)
+    if isinstance(estimator, list):
+        ret = estimator
+        if len(ret) != n:
             msg = ('the number of estimators ({n_ests}) being fit is not equal to '
                    'the number of splits for cross validation ({n_splits}).')
-            raise ValueError(msg.format(n_ests=len(ests), n_splits=cv.n_splits))
-        return ests
-    return (clone(estimator) for _ in range(n))
+            raise ValueError(msg.format(n_ests=len(ests), n_splits=n))
+        return ret
+    else:
+        ret = [clone(estimator) for _ in range(n)]
+    return ret
 
 
 def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,

diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py
index d1408a3fd71b9..399dd81a5af04 100644
--- a/sklearn/model_selection/tests/test_validation.py
+++ b/sklearn/model_selection/tests/test_validation.py
@@ -10,7 +10,7 @@
 import pytest
 import numpy as np
 from scipy.sparse import coo_matrix, csr_matrix
-from sklearn.exceptions import FitFailedWarning, NotFittedError
+from sklearn.exceptions import FitFailedWarning
 
 from sklearn.tests.test_grid_search import FailingClassifier
 
@@ -636,8 +636,8 @@ def test_cross_validate_fit_kwarg(partial_fit):
     scores_fit = cross_validate(clf_fit, X, y, return_estimator=True, cv=cv)
 
     assert_true(set(scores_p_fit.keys()) == set(scores_fit.keys()))
 
-    clfs_p_fit = scores_p_fit.pop('estimator')
-    clfs_fit = scores_fit.pop('estimator')
+    clfs_p_fit = scores_p_fit['estimator']
+    clfs_fit = scores_fit['estimator']
     for clf_fit, clf_p_fit in zip(clfs_fit, clfs_p_fit):
         assert_true(clf_p_fit.t_ * 10 < clf_fit.t_)
         assert_true(clf_p_fit.t_ - 1 ==
@@ -669,11 +669,41 @@ def test_cross_validate_val_set():
     y_test = (np.sign(rng.randn(n)) + 1) / 2
 
     clf = SGDClassifier(random_state=0)
-    ret = cross_validate(clf, X_train, y_train, val_data=(X_test, y_test))
+    r = cross_validate(clf, X_train, y_train, test_data=(X_test, y_test))
 
-    assert_true(ret['test_score'].max() < ret['train_score'].min())
-    assert_true(ret['test_score'].max() < 0.48)
-    assert_true(0.85 < ret['train_score'].min())
+    assert_true(r['test_score'].max() < 0.48 < 0.85 < r['train_score'].min())
+
+
+def test_cross_validate_repeated_call():
+    n, d = 100, 80
+    cv = 3
+    X, y = make_classification(n_samples=n, n_features=d, n_classes=2,
+                               random_state=0, n_redundant=0,
+                               n_informative=2)
+    classes = np.unique(y)
+    one_epoch = X.shape[0] * (cv - 1) / cv
+
+    clf = SGDClassifier(random_state=0)
+    ret1 = cross_validate(clf, X, y, fit_params={'classes': classes},
+                          return_estimator=True, partial_fit=True, cv=cv)
+    iters1 = [(est.t_ - 1) / one_epoch for est in ret1['estimator']]
+    assert isinstance(ret1['estimator'], tuple)
+    assert all([isinstance(e, BaseEstimator) for e in ret1['estimator']])
+
+    ret2 = cross_validate(ret1['estimator'], X, y, return_estimator=True, partial_fit=True,
+                          cv=cv)
+
+    assert set(ret1.keys()) == set(ret2.keys())
+    for k, v1 in ret1.items():
+        v2 = ret2[k]
+        assert len(v1) == len(v2)
+        if k == 'train_score':
+            assert v1.mean() < 0.90 < 0.93 < v2.mean()
+        if k == 'test_score':
+            assert v1.mean() < 0.73 < 0.75 < v2.mean()
+
+    iters2 = [(est.t_ - 1) / one_epoch for est in ret2['estimator']]
+    assert sum(iters1) / cv == 1 and sum(iters2) / cv == 2
 
 
 def test_cross_val_score_score_func():

From 1059e453c6c230c9ea51330bd470dd95ed563ba5 Mon Sep 17 00:00:00 2001
From: 
Scott Sievert Date: Fri, 15 Jun 2018 10:56:38 -0500 Subject: [PATCH 5/7] MAINT: rename val_data to X_test and y_test --- sklearn/model_selection/_validation.py | 18 +++++++----------- .../model_selection/tests/test_validation.py | 8 ++++---- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index c1bd92306159a..c063e27d08266 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -31,8 +31,6 @@ from ..exceptions import FitFailedWarning from ._split import check_cv from ..preprocessing import LabelEncoder -from copy import deepcopy -from ..base import BaseEstimator __all__ = ['cross_validate', 'cross_val_score', 'cross_val_predict', @@ -42,7 +40,8 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None, n_jobs=1, verbose=0, fit_params=None, pre_dispatch='2*n_jobs', return_train_score="warn", - return_estimator=False, partial_fit=False, test_data=None): + return_estimator=False, partial_fit=False, X_test=None, + y_test=None): """Evaluate metric(s) by cross-validation and also record fit/score times. Read more in the :ref:`User Guide `. @@ -224,7 +223,7 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None, est, X, y, scorers, train, test, verbose, None, fit_params, return_train_score=return_train_score, return_times=True, return_estimator=return_estimator, partial_fit=partial_fit, - test_data=test_data) + X_test=X_test, y_test=y_test) for est, (train, test) in zip(ests, cv.split(X, y, groups))) zipped_scores = list(zip(*scores)) @@ -267,9 +266,9 @@ def _get_estimators(estimator, n): if isinstance(estimator, list): ret = estimator if len(ret) != n: - msg = ('the number of estimators ({n_ests}) being fit is not equal to ' + msg = ('the number of estimators ({n}) being fit is not equal to ' 'the number of splits for cross validation ({n_splits}).') - raise ValueError(msg.format(n_ests=len(ests), n_splits=n)) + raise ValueError(msg.format(n=len(ret), n_splits=n)) return ret else: ret = [clone(estimator) for _ in range(n)] @@ -394,7 +393,7 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, return_parameters=False, return_n_test_samples=False, return_times=False, return_estimator=False, error_score='raise-deprecating', partial_fit=False, - test_data=None): + X_test=None, y_test=None): """Fit estimator and compute scores for a given dataset split. 
Parameters @@ -505,13 +504,11 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, start_time = time.time() - if test_data is None: + if X_test is None and y_test is None: X_train, y_train = _safe_split(estimator, X, y, train) X_test, y_test = _safe_split(estimator, X, y, test, train) else: X_train, y_train = X, y - X_test = test_data[0] - y_test = test_data[1] if y_train is not None else None train_data = (X_train, y_train) if y_train is not None else (X_train, ) is_multimetric = not callable(scorer) @@ -525,7 +522,6 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, else: for _ in range(int(partial_fit)): estimator.partial_fit(*train_data, **fit_params) - except Exception as e: # Note fit time as time until error fit_time = time.time() - start_time diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index 399dd81a5af04..5635a1947d2fb 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -669,9 +669,9 @@ def test_cross_validate_val_set(): y_test = (np.sign(rng.randn(n)) + 1) / 2 clf = SGDClassifier(random_state=0) - r = cross_validate(clf, X_train, y_train, test_data=(X_test, y_test)) + r = cross_validate(clf, X_train, y_train, X_test=X_test, y_test=y_test) - assert_true(r['test_score'].max() < 0.48 < 0.85 < r['train_score'].min()) + assert_true(r['test_score'].mean() < 0.48 < 0.85 < r['train_score'].mean()) def test_cross_validate_repeated_call(): @@ -690,8 +690,8 @@ def test_cross_validate_repeated_call(): assert isinstance(ret1['estimator'], tuple) assert all([isinstance(e, BaseEstimator) for e in ret1['estimator']]) - ret2 = cross_validate(ret1['estimator'], X, y, return_estimator=True, partial_fit=True, - cv=cv) + ret2 = cross_validate(ret1['estimator'], X, y, return_estimator=True, + partial_fit=True, cv=cv) assert set(ret1.keys()) == set(ret2.keys()) for k, v1 in ret1.items(): From 32fb4816b85ce97aa8ea23a7fa2dde6f5f7a0df8 Mon Sep 17 00:00:00 2001 From: Scott Sievert Date: Fri, 15 Jun 2018 11:22:41 -0500 Subject: [PATCH 6/7] MAINT: flake8 passes --- sklearn/model_selection/_split.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index 866cb4cc53aa8..202009ab4951a 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -2058,6 +2058,7 @@ def train_test_split(*arrays, **options): # Tell nose that train_test_split is not a test train_test_split.__test__ = False + def _build_repr(self): # XXX This is copied from BaseEstimator's get_params cls = self.__class__ From 0723dc9595e7b3b3258b00ba1bb1474d3dea54ca Mon Sep 17 00:00:00 2001 From: Scott Sievert Date: Fri, 15 Jun 2018 11:36:42 -0500 Subject: [PATCH 7/7] DOC: better document --- doc/whats_new/v0.20.rst | 9 +++++---- sklearn/model_selection/_validation.py | 14 ++++++++++++-- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 92bbbe05e024d..f7dc2fba5538b 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -109,10 +109,11 @@ Model evaluation ``'balanced_accuracy'`` scorer for binary classification. :issue:`8066` by :user:`xyguo` and :user:`Aman Dalmia `. -- A ``fit`` keyword argument has been added to - :func:`model_selection.cross_val_score` and - :func:`model_selection.cross_validate` to control calling ``estimator.fit`` - before scoring. See :issue:`` by :user:`Scott Sievert `. 
+- The keyword arguments ``partial_fit``, ``X_test`` and ``y_test`` have been
+  added to :func:`model_selection.cross_validate`: ``partial_fit`` calls
+  ``estimator.partial_fit`` instead of ``estimator.fit`` before scoring, and
+  ``X_test``/``y_test`` supply an explicit validation set. :issue:`11266` by
+  :user:`Scott Sievert `.
 
 Decomposition, manifold learning and clustering
 
diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py
index c063e27d08266..d0dd71dcc3b6e 100644
--- a/sklearn/model_selection/_validation.py
+++ b/sklearn/model_selection/_validation.py
@@ -48,8 +48,10 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None,
 
     Parameters
     ----------
-    estimator : estimator object implementing 'fit'
-        The object to use to fit the data.
+    estimator : estimator object implementing 'fit', or list of estimators
+        The object to use to fit the data. If a list, the estimators are not
+        cloned and the list must have the same length as the number of
+        cross-validation splits.
 
     X : array-like
         The data to fit. Can be for example a list, or an array.
@@ -140,6 +142,14 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None,
         that many times. ``estimator`` is assumed to be pickleable if
         ``partial_fit`` is not False.
 
+    X_test : array-like, optional
+        If present, treat ``X_test`` (together with ``y_test``) as the
+        validation set, and use all of ``X`` and ``y`` as the training set.
+
+    y_test : array-like, optional
+        If present, treat ``y_test`` (together with ``X_test``) as the
+        validation set, and use all of ``X`` and ``y`` as the training set.
+
     Returns
     -------
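
(Usage sketch, illustrative and not part of the series: the combined API
after PATCH 7/7, using ``partial_fit``, repeated calls with the list of
fitted estimators, and an explicit validation set via ``X_test``/``y_test``.
``SGDClassifier`` stands in for any estimator with ``partial_fit``.)

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.linear_model import SGDClassifier
    from sklearn.model_selection import cross_validate

    X, y = make_classification(n_samples=100, random_state=0)
    X_val, y_val = make_classification(n_samples=50, random_state=1)
    classes = np.unique(y)

    # One pass of partial_fit per split; keep the fitted estimators.
    ret = cross_validate(SGDClassifier(random_state=0), X, y, cv=3,
                         partial_fit=True, return_estimator=True,
                         fit_params={'classes': classes})

    # Resume training: passing the per-split estimators back in skips
    # cloning, so each one continues from its previous state.
    ret = cross_validate(ret['estimator'], X, y, cv=3,
                         partial_fit=True, return_estimator=True)

    # Score every split against one fixed validation set.
    ret = cross_validate(SGDClassifier(random_state=0), X, y,
                         X_test=X_val, y_test=y_val)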