diff --git a/doc/metadata_routing.rst b/doc/metadata_routing.rst
index 0ada6ef6c4dbe..27000a192ab21 100644
--- a/doc/metadata_routing.rst
+++ b/doc/metadata_routing.rst
@@ -292,6 +292,7 @@ Meta-estimators and functions supporting metadata routing:
 - :class:`sklearn.linear_model.LogisticRegressionCV`
 - :class:`sklearn.linear_model.MultiTaskElasticNetCV`
 - :class:`sklearn.linear_model.MultiTaskLassoCV`
+- :class:`sklearn.linear_model.OrthogonalMatchingPursuitCV`
 - :class:`sklearn.linear_model.RANSACRegressor`
 - :class:`sklearn.linear_model.RidgeClassifierCV`
 - :class:`sklearn.linear_model.RidgeCV`
@@ -302,13 +303,13 @@ Meta-estimators and functions supporting metadata routing:
 - :func:`sklearn.model_selection.cross_validate`
 - :func:`sklearn.model_selection.cross_val_score`
 - :func:`sklearn.model_selection.cross_val_predict`
+- :func:`sklearn.model_selection.learning_curve`
 - :class:`sklearn.multiclass.OneVsOneClassifier`
 - :class:`sklearn.multiclass.OneVsRestClassifier`
 - :class:`sklearn.multiclass.OutputCodeClassifier`
 - :class:`sklearn.multioutput.ClassifierChain`
 - :class:`sklearn.multioutput.MultiOutputClassifier`
 - :class:`sklearn.multioutput.MultiOutputRegressor`
-- :class:`sklearn.linear_model.OrthogonalMatchingPursuitCV`
 - :class:`sklearn.multioutput.RegressorChain`
 - :class:`sklearn.pipeline.FeatureUnion`
 - :class:`sklearn.pipeline.Pipeline`
@@ -321,7 +322,6 @@ Meta-estimators and tools not supporting metadata routing yet:
 - :class:`sklearn.feature_selection.RFE`
 - :class:`sklearn.feature_selection.RFECV`
 - :class:`sklearn.feature_selection.SequentialFeatureSelector`
-- :class:`sklearn.model_selection.learning_curve`
 - :class:`sklearn.model_selection.permutation_test_score`
 - :class:`sklearn.model_selection.validation_curve`
 - :class:`sklearn.semi_supervised.SelfTrainingClassifier`
diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst
index 5000866b59c03..30e9e14d6c6df 100644
--- a/doc/whats_new/v1.6.rst
+++ b/doc/whats_new/v1.6.rst
@@ -47,6 +47,10 @@ The following models now support metadata routing in one or more of their
 methods. Refer to the :ref:`Metadata Routing User Guide <metadata_routing>` for
 more details.
 
+- |Feature| :func:`model_selection.learning_curve` now supports metadata routing for the
+  `fit` method of its estimator and for its underlying CV splitter and scorer.
+  :pr:`28975` by :user:`Stefanie Senger <StefanieSenger>`.
+
 - |Feature| :class:`ensemble.StackingClassifier` and :class:`ensemble.StackingRegressor`
   now support metadata routing and pass ``**fit_params`` to the underlying estimators
   via their `fit` methods.
diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py
index 176627ace91d4..83d289d36efb2 100644
--- a/sklearn/model_selection/_validation.py
+++ b/sklearn/model_selection/_validation.py
@@ -58,21 +58,22 @@
 ]
 
 
-def _check_params_groups_deprecation(fit_params, params, groups):
+def _check_params_groups_deprecation(fit_params, params, groups, version):
     """A helper function to check deprecations on `groups` and `fit_params`.
 
-    To be removed when set_config(enable_metadata_routing=False) is not possible.
+    # TODO(SLEP6): To be removed when set_config(enable_metadata_routing=False) is not
+    # possible.
     """
     if params is not None and fit_params is not None:
         raise ValueError(
             "`params` and `fit_params` cannot both be provided. Pass parameters "
             "via `params`. `fit_params` is deprecated and will be removed in "
-            "version 1.6."
+            f"version {version}."
         )
     elif fit_params is not None:
         warnings.warn(
             (
-                "`fit_params` is deprecated and will be removed in version 1.6. "
+                f"`fit_params` is deprecated and will be removed in version {version}. "
                 "Pass parameters via `params` instead."
             ),
             FutureWarning,
@@ -346,7 +347,7 @@ def cross_validate(
     >>> print(scores['train_r2'])
     [0.28009951 0.3908844  0.22784907]
     """
-    params = _check_params_groups_deprecation(fit_params, params, groups)
+    params = _check_params_groups_deprecation(fit_params, params, groups, "1.6")
 
     X, y = indexable(X, y)
 
@@ -602,10 +603,8 @@ def cross_val_score(
         ``cross_val_score(..., params={'groups': groups})``.
 
     scoring : str or callable, default=None
-        A str (see model evaluation documentation) or
-        a scorer callable object / function with signature
-        ``scorer(estimator, X, y)`` which should return only
-        a single value.
+        A str (see :ref:`scoring_parameter`) or a scorer callable object / function with
+        signature ``scorer(estimator, X, y)`` which should return only a single value.
 
         Similar to :func:`cross_validate`
         but only a single metric is permitted.
@@ -1206,7 +1205,7 @@ def cross_val_predict(
     >>> lasso = linear_model.Lasso()
     >>> y_pred = cross_val_predict(lasso, X, y, cv=3)
     """
-    params = _check_params_groups_deprecation(fit_params, params, groups)
+    params = _check_params_groups_deprecation(fit_params, params, groups, "1.6")
     X, y = indexable(X, y)
 
     if _routing_enabled():
@@ -1718,6 +1717,7 @@ def _shuffle(y, groups, random_state):
         "error_score": [StrOptions({"raise"}), Real],
         "return_times": ["boolean"],
         "fit_params": [dict, None],
+        "params": [dict, None],
     },
     prefer_skip_nested_validation=False,  # estimator is not validated yet
 )
@@ -1739,6 +1739,7 @@ def learning_curve(
     error_score=np.nan,
     return_times=False,
     fit_params=None,
+    params=None,
 ):
     """Learning curve.
 
@@ -1773,6 +1774,13 @@
         train/test set. Only used in conjunction with a "Group" :term:`cv`
         instance (e.g., :class:`GroupKFold`).
 
+        .. versionchanged:: 1.6
+            ``groups`` can only be passed if metadata routing is not enabled
+            via ``sklearn.set_config(enable_metadata_routing=True)``. When routing
+            is enabled, pass ``groups`` alongside other metadata via the ``params``
+            argument instead. E.g.:
+            ``learning_curve(..., params={'groups': groups})``.
+
     train_sizes : array-like of shape (n_ticks,), \
             default=np.linspace(0.1, 1.0, 5)
         Relative or absolute numbers of training examples that will be used to
@@ -1780,7 +1788,7 @@
         fraction of the maximum size of the training set (that is determined
         by the selected validation method), i.e. it has to be within (0, 1].
         Otherwise it is interpreted as absolute sizes of the training sets.
-        Note that for classification the number of samples usually have to
+        Note that for classification the number of samples usually has to
         be big enough to contain at least one sample from each class.
 
     cv : int, cross-validation generator or an iterable, default=None
@@ -1804,9 +1812,8 @@
             ``cv`` default value if None changed from 3-fold to 5-fold.
 
     scoring : str or callable, default=None
-        A str (see model evaluation documentation) or
-        a scorer callable object / function with signature
-        ``scorer(estimator, X, y)``.
+        A str (see :ref:`scoring_parameter`) or a scorer callable object / function with
+        signature ``scorer(estimator, X, y)``.
 
     exploit_incremental_learning : bool, default=False
         If the estimator supports incremental learning, this will be
@@ -1849,7 +1856,22 @@
     fit_params : dict, default=None
         Parameters to pass to the fit method of the estimator.
 
-        .. versionadded:: 0.24
+        .. deprecated:: 1.6
+            This parameter is deprecated and will be removed in version 1.8. Use
+            ``params`` instead.
+
+    params : dict, default=None
+        Parameters to pass to the `fit` method of the estimator and to the scorer.
+
+        - If `enable_metadata_routing=False` (default):
+          Parameters directly passed to the `fit` method of the estimator.
+
+        - If `enable_metadata_routing=True`:
+          Parameters safely routed to the `fit` method of the estimator.
+          See :ref:`Metadata Routing User Guide <metadata_routing>` for more
+          details.
+
+        .. versionadded:: 1.6
 
     Returns
     -------
@@ -1903,14 +1925,69 @@
             "An estimator must support the partial_fit interface "
             "to exploit incremental learning"
         )
+
+    params = _check_params_groups_deprecation(fit_params, params, groups, "1.8")
+
     X, y, groups = indexable(X, y, groups)
 
     cv = check_cv(cv, y, classifier=is_classifier(estimator))
-    # Store it as list as we will be iterating over the list multiple times
-    cv_iter = list(cv.split(X, y, groups))
 
     scorer = check_scoring(estimator, scoring=scoring)
 
+    if _routing_enabled():
+        router = (
+            MetadataRouter(owner="learning_curve")
+            .add(
+                estimator=estimator,
+                # TODO(SLEP6): also pass metadata to the predict method for
+                # scoring?
+                method_mapping=MethodMapping()
+                .add(caller="fit", callee="fit")
+                .add(caller="fit", callee="partial_fit"),
+            )
+            .add(
+                splitter=cv,
+                method_mapping=MethodMapping().add(caller="fit", callee="split"),
+            )
+            .add(
+                scorer=scorer,
+                method_mapping=MethodMapping().add(caller="fit", callee="score"),
+            )
+        )
+
+        try:
+            routed_params = process_routing(router, "fit", **params)
+        except UnsetMetadataPassedError as e:
+            # The default exception would mention `fit` since in the above
+            # `process_routing` code, we pass `fit` as the caller. However,
+            # the user is not calling `fit` directly, so we change the message
+            # to make it more suitable for this case.
+            unrequested_params = sorted(e.unrequested_params)
+            raise UnsetMetadataPassedError(
+                message=(
+                    f"{unrequested_params} are passed to `learning_curve` but are not"
+                    " explicitly set as requested or not requested for learning_curve's"
+                    f" estimator: {estimator.__class__.__name__}. Call"
+                    " `.set_fit_request({{metadata}}=True)` on the estimator for"
+                    f" each metadata in {unrequested_params} that you"
+                    " want to use and `metadata=False` for not using it. See the"
+                    " Metadata Routing User guide"
+                    " <https://scikit-learn.org/stable/metadata_routing.html> for more"
+                    " information."
+ ), + unrequested_params=e.unrequested_params, + routed_params=e.routed_params, + ) + + else: + routed_params = Bunch() + routed_params.estimator = Bunch(fit=params, partial_fit=params) + routed_params.splitter = Bunch(split={"groups": groups}) + routed_params.scorer = Bunch(score={}) + + # Store cv as list as we will be iterating over the list multiple times + cv_iter = list(cv.split(X, y, **routed_params.splitter.split)) + n_max_training_samples = len(cv_iter[0][0]) # Because the lengths of folds can be significantly different, it is # not guaranteed that we use all of the available training data when we @@ -1940,7 +2017,8 @@ def learning_curve( scorer, return_times, error_score=error_score, - fit_params=fit_params, + fit_params=routed_params.estimator.partial_fit, + score_params=routed_params.scorer.score, ) for train, test in cv_iter ) @@ -1961,9 +2039,8 @@ def learning_curve( test=test, verbose=verbose, parameters=None, - fit_params=fit_params, - # TODO(SLEP6): support score params here - score_params=None, + fit_params=routed_params.estimator.fit, + score_params=routed_params.scorer.score, return_train_score=True, error_score=error_score, return_times=return_times, @@ -2069,6 +2146,7 @@ def _incremental_fit_estimator( return_times, error_score, fit_params, + score_params, ): """Train estimator on training subsets incrementally and compute scores.""" train_scores, test_scores, fit_times, score_times = [], [], [], [] @@ -2079,6 +2157,9 @@ def _incremental_fit_estimator( partial_fit_func = partial(estimator.partial_fit, **fit_params) else: partial_fit_func = partial(estimator.partial_fit, classes=classes, **fit_params) + score_params = score_params if score_params is not None else {} + score_params_train = _check_method_params(X, params=score_params, indices=train) + score_params_test = _check_method_params(X, params=score_params, indices=test) for n_train_samples, partial_train in partitions: train_subset = train[:n_train_samples] @@ -2095,14 +2176,13 @@ def _incremental_fit_estimator( start_score = time.time() - # TODO(SLEP6): support score params in the following two calls test_scores.append( _score( estimator, X_test, y_test, scorer, - score_params=None, + score_params=score_params_test, error_score=error_score, ) ) @@ -2112,7 +2192,7 @@ def _incremental_fit_estimator( X_train, y_train, scorer, - score_params=None, + score_params=score_params_train, error_score=error_score, ) ) @@ -2220,9 +2300,8 @@ def validation_curve( ``cv`` default value if None changed from 3-fold to 5-fold. scoring : str or callable, default=None - A str (see model evaluation documentation) or - a scorer callable object / function with signature - ``scorer(estimator, X, y)``. + A str (see :ref:`scoring_parameter`) or a scorer callable object / function with + signature ``scorer(estimator, X, y)``. n_jobs : int, default=None Number of jobs to run in parallel. 
Training the estimator and computing diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index a1a860b243249..679c0052e3956 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -1535,7 +1535,7 @@ def test_learning_curve_with_shuffle(): ) -def test_learning_curve_fit_params(): +def test_learning_curve_params(): X = np.arange(100).reshape(10, 10) y = np.array([0] * 5 + [1] * 5) clf = CheckingClassifier(expected_sample_weight=True) @@ -1547,14 +1547,14 @@ def test_learning_curve_fit_params(): err_msg = r"sample_weight.shape == \(1,\), expected \(2,\)!" with pytest.raises(ValueError, match=err_msg): learning_curve( - clf, X, y, error_score="raise", fit_params={"sample_weight": np.ones(1)} + clf, X, y, error_score="raise", params={"sample_weight": np.ones(1)} ) learning_curve( - clf, X, y, error_score="raise", fit_params={"sample_weight": np.ones(10)} + clf, X, y, error_score="raise", params={"sample_weight": np.ones(10)} ) -def test_learning_curve_incremental_learning_fit_params(): +def test_learning_curve_incremental_learning_params(): X, y = make_classification( n_samples=30, n_features=1, @@ -1587,7 +1587,7 @@ def test_learning_curve_incremental_learning_fit_params(): exploit_incremental_learning=True, train_sizes=np.linspace(0.1, 1.0, 10), error_score="raise", - fit_params={"sample_weight": np.ones(3)}, + params={"sample_weight": np.ones(3)}, ) learning_curve( @@ -1598,7 +1598,7 @@ def test_learning_curve_incremental_learning_fit_params(): exploit_incremental_learning=True, train_sizes=np.linspace(0.1, 1.0, 10), error_score="raise", - fit_params={"sample_weight": np.ones(2)}, + params={"sample_weight": np.ones(2)}, ) @@ -2481,34 +2481,34 @@ def test_cross_validate_return_indices(global_random_seed): assert_array_equal(test_indices[split_idx], expected_test_idx) -# Tests for metadata routing in cross_val* -# ======================================== +# Tests for metadata routing in cross_val* and learning_curve +# =========================================================== -# TODO(1.6): remove this test in 1.6 -def test_cross_validate_fit_param_deprecation(): +# TODO(1.6): remove `cross_validate` and `cross_val_predict` from this test in 1.6 and +# `learning_curve` in 1.8 +@pytest.mark.parametrize("func", [cross_validate, cross_val_predict, learning_curve]) +def test_fit_param_deprecation(func): """Check that we warn about deprecating `fit_params`.""" with pytest.warns(FutureWarning, match="`fit_params` is deprecated"): - cross_validate(estimator=ConsumingClassifier(), X=X, y=y, cv=2, fit_params={}) + func(estimator=ConsumingClassifier(), X=X, y=y, cv=2, fit_params={}) with pytest.raises( ValueError, match="`params` and `fit_params` cannot both be provided" ): - cross_validate( - estimator=ConsumingClassifier(), X=X, y=y, fit_params={}, params={} - ) + func(estimator=ConsumingClassifier(), X=X, y=y, fit_params={}, params={}) @pytest.mark.usefixtures("enable_slep006") @pytest.mark.parametrize( - "cv_method", [cross_validate, cross_val_score, cross_val_predict] + "func", [cross_validate, cross_val_score, cross_val_predict, learning_curve] ) -def test_groups_with_routing_validation(cv_method): +def test_groups_with_routing_validation(func): """Check that we raise an error if `groups` are passed to the cv method instead of `params` when metadata routing is enabled. 
""" with pytest.raises(ValueError, match="`groups` can only be passed if"): - cv_method( + func( estimator=ConsumingClassifier(), X=X, y=y, @@ -2518,14 +2518,14 @@ def test_groups_with_routing_validation(cv_method): @pytest.mark.usefixtures("enable_slep006") @pytest.mark.parametrize( - "cv_method", [cross_validate, cross_val_score, cross_val_predict] + "func", [cross_validate, cross_val_score, cross_val_predict, learning_curve] ) -def test_passed_unrequested_metadata(cv_method): +def test_passed_unrequested_metadata(func): """Check that we raise an error when passing metadata that is not requested.""" err_msg = re.escape("but are not explicitly set as requested or not requested") with pytest.raises(ValueError, match=err_msg): - cv_method( + func( estimator=ConsumingClassifier(), X=X, y=y, @@ -2535,9 +2535,9 @@ def test_passed_unrequested_metadata(cv_method): @pytest.mark.usefixtures("enable_slep006") @pytest.mark.parametrize( - "cv_method", [cross_validate, cross_val_score, cross_val_predict] + "func", [cross_validate, cross_val_score, cross_val_predict, learning_curve] ) -def test_cross_validate_routing(cv_method): +def test_validation_functions_routing(func): """Check that the respective cv method is properly dispatching the metadata to the consumer.""" scorer_registry = _Registry() @@ -2552,6 +2552,7 @@ def test_cross_validate_routing(cv_method): estimator = ConsumingClassifier(registry=estimator_registry).set_fit_request( sample_weight="fit_sample_weight", metadata="fit_metadata" ) + n_samples = _num_samples(X) rng = np.random.RandomState(0) score_weights = rng.rand(n_samples) @@ -2563,8 +2564,9 @@ def test_cross_validate_routing(cv_method): extra_params = { cross_validate: dict(scoring=dict(my_scorer=scorer, accuracy="accuracy")), - # cross_val_score doesn't support multiple scorers + # cross_val_score and learning_curve don't support multiple scorers: cross_val_score: dict(scoring=scorer), + learning_curve: dict(scoring=scorer), # cross_val_predict doesn't need a scorer cross_val_predict: dict(), } @@ -2576,22 +2578,22 @@ def test_cross_validate_routing(cv_method): fit_metadata=fit_metadata, ) - if cv_method is not cross_val_predict: + if func is not cross_val_predict: params.update( score_weights=score_weights, score_metadata=score_metadata, ) - cv_method( + func( estimator, X=X, y=y, cv=splitter, - **extra_params[cv_method], + **extra_params[func], params=params, ) - if cv_method is not cross_val_predict: + if func is not cross_val_predict: # cross_val_predict doesn't need a scorer assert len(scorer_registry) for _scorer in scorer_registry: @@ -2623,5 +2625,42 @@ def test_cross_validate_routing(cv_method): ) +@pytest.mark.usefixtures("enable_slep006") +def test_learning_curve_exploit_incremental_learning_routing(): + """Test that learning_curve routes metadata to the estimator correctly while + partial_fitting it with `exploit_incremental_learning=True`.""" + + n_samples = _num_samples(X) + rng = np.random.RandomState(0) + fit_sample_weight = rng.rand(n_samples) + fit_metadata = rng.rand(n_samples) + + estimator_registry = _Registry() + estimator = ConsumingClassifier( + registry=estimator_registry + ).set_partial_fit_request( + sample_weight="fit_sample_weight", metadata="fit_metadata" + ) + + learning_curve( + estimator, + X=X, + y=y, + cv=ConsumingSplitter(), + exploit_incremental_learning=True, + params=dict(fit_sample_weight=fit_sample_weight, fit_metadata=fit_metadata), + ) + + assert len(estimator_registry) + for _estimator in estimator_registry: + check_recorded_metadata( + 
obj=_estimator, + method="partial_fit", + split_params=("sample_weight", "metadata"), + sample_weight=fit_sample_weight, + metadata=fit_metadata, + ) + + # End of metadata routing tests # ============================= diff --git a/sklearn/utils/_metadata_requests.py b/sklearn/utils/_metadata_requests.py index f730539621177..02a79bb8a6f20 100644 --- a/sklearn/utils/_metadata_requests.py +++ b/sklearn/utils/_metadata_requests.py @@ -999,8 +999,9 @@ def _route_params(self, *, params, method, parent, caller): def route_params(self, *, caller, params): """Return the input parameters requested by child objects. - The output of this method is a bunch, which includes the metadata for all - methods of each child object that is used in the router's `caller` method. + The output of this method is a :class:`~sklearn.utils.Bunch`, which includes the + metadata for all methods of each child object that is used in the router's + `caller` method. If the router is also a consumer, it also checks for warnings of `self`'s/consumer's requested metadata. diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index cdda749ec70a2..4e25750290a7a 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -488,7 +488,7 @@ def indexable(*iterables): Checks consistent length, passes through None, and ensures that everything can be indexed by converting sparse matrices to csr and converting - non-interable objects to arrays. + non-iterable objects to arrays. Parameters ----------
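
A minimal usage sketch of the behaviour this patch enables (illustrative only, not part of the diff; the estimator, scorer, and ``sample_weight`` metadata below are assumptions chosen for the example): with ``enable_metadata_routing=True``, entries of ``params`` are routed by ``learning_curve`` to whichever consumers explicitly requested them.

    # Hypothetical example: route sample_weight to the estimator's fit and to the scorer.
    import numpy as np
    from sklearn import set_config
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import get_scorer
    from sklearn.model_selection import learning_curve

    set_config(enable_metadata_routing=True)

    X, y = make_classification(n_samples=100, random_state=0)
    sample_weight = np.linspace(0.5, 1.5, 100)

    # Each consumer must explicitly request the metadata it wants to receive.
    estimator = LogisticRegression().set_fit_request(sample_weight=True)
    scorer = get_scorer("accuracy").set_score_request(sample_weight=True)

    train_sizes, train_scores, test_scores = learning_curve(
        estimator,
        X,
        y,
        cv=5,
        scoring=scorer,
        params={"sample_weight": sample_weight},
    )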