diff --git a/doc/modules/classification_threshold.rst b/doc/modules/classification_threshold.rst index 712a094a43246..236c0736f7d23 100644 --- a/doc/modules/classification_threshold.rst +++ b/doc/modules/classification_threshold.rst @@ -143,7 +143,8 @@ Manually setting the decision threshold The previous sections discussed strategies to find an optimal decision threshold. It is also possible to manually set the decision threshold using the class -:class:`~sklearn.model_selection.FixedThresholdClassifier`. +:class:`~sklearn.model_selection.FixedThresholdClassifier`. In case that you don't want +to refit the model when calling `fit`, you can set the parameter `prefit=True`. Examples -------- diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index 87814e102ad98..dabe83fb53106 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -88,6 +88,14 @@ Changelog whether to raise an exception if a subset of the scorers in multimetric scoring fails or to return an error code. :pr:`28992` by :user:`Stefanie Senger `. +:mod:`sklearn.model_selection` +.............................. + +- |Enhancement| Add the parameter `prefit` to + :class:`model_selection.FixedThresholdClassifier` allowing the use of a pre-fitted + estimator without re-fitting it. + :pr:`29067` by :user:`Guillaume Lemaitre `. + Thanks to everyone who has contributed to the maintenance and improvement of the project since version 1.5, including: diff --git a/sklearn/model_selection/_classification_threshold.py b/sklearn/model_selection/_classification_threshold.py index e090a3d042746..1d221d3388434 100644 --- a/sklearn/model_selection/_classification_threshold.py +++ b/sklearn/model_selection/_classification_threshold.py @@ -271,6 +271,13 @@ class FixedThresholdClassifier(BaseThresholdClassifier): If the method is not implemented by the classifier, it will raise an error. + prefit : bool, default=False + Whether a pre-fitted model is expected to be passed into the constructor + directly or not. If `True`, `estimator` must be a fitted estimator. If `False`, + `estimator` is fitted and updated by calling `fit`. + + .. versionadded:: 1.6 + Attributes ---------- estimator_ : estimator instance @@ -322,6 +329,7 @@ class FixedThresholdClassifier(BaseThresholdClassifier): **BaseThresholdClassifier._parameter_constraints, "threshold": [StrOptions({"auto"}), Real], "pos_label": [Real, str, "boolean", None], + "prefit": ["boolean"], } def __init__( @@ -331,10 +339,12 @@ def __init__( threshold="auto", pos_label=None, response_method="auto", + prefit=False, ): super().__init__(estimator=estimator, response_method=response_method) self.pos_label = pos_label self.threshold = threshold + self.prefit = prefit def _fit(self, X, y, **params): """Fit the classifier. @@ -357,7 +367,13 @@ def _fit(self, X, y, **params): Returns an instance of self. """ routed_params = process_routing(self, "fit", **params) - self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit) + if self.prefit: + check_is_fitted(self.estimator) + self.estimator_ = self.estimator + else: + self.estimator_ = clone(self.estimator).fit( + X, y, **routed_params.estimator.fit + ) return self def predict(self, X): diff --git a/sklearn/model_selection/tests/test_classification_threshold.py b/sklearn/model_selection/tests/test_classification_threshold.py index f64edb2563c76..77c4c20e99ef2 100644 --- a/sklearn/model_selection/tests/test_classification_threshold.py +++ b/sklearn/model_selection/tests/test_classification_threshold.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from sklearn.base import clone +from sklearn.base import BaseEstimator, ClassifierMixin, clone from sklearn.datasets import ( load_breast_cancer, load_iris, @@ -682,3 +682,43 @@ def test_fixed_threshold_classifier_metadata_routing(): classifier_default_threshold = FixedThresholdClassifier(estimator=clone(classifier)) classifier_default_threshold.fit(X, y, sample_weight=sample_weight) assert_allclose(classifier_default_threshold.estimator_.coef_, classifier.coef_) + + +class ClassifierLoggingFit(ClassifierMixin, BaseEstimator): + """Classifier that logs the number of `fit` calls.""" + + def __init__(self, fit_calls=0): + self.fit_calls = fit_calls + + def fit(self, X, y, **fit_params): + self.fit_calls += 1 + self.is_fitted_ = True + return self + + def predict_proba(self, X): + return np.ones((X.shape[0], 2), np.float64) # pragma: nocover + + +def test_fixed_threshold_classifier_prefit(): + """Check the behaviour of the `FixedThresholdClassifier` with the `prefit` + parameter.""" + X, y = make_classification(random_state=0) + + estimator = ClassifierLoggingFit() + model = FixedThresholdClassifier(estimator=estimator, prefit=True) + with pytest.raises(NotFittedError): + model.fit(X, y) + + # check that we don't clone the classifier when `prefit=True`. + estimator.fit(X, y) + model.fit(X, y) + assert estimator.fit_calls == 1 + assert model.estimator_ is estimator + + # check that we clone the classifier when `prefit=False`. + estimator = ClassifierLoggingFit() + model = FixedThresholdClassifier(estimator=estimator, prefit=False) + model.fit(X, y) + assert estimator.fit_calls == 0 + assert model.estimator_.fit_calls == 1 + assert model.estimator_ is not estimator