From 0a0afb563fae01e277bc4b90dfb6c321af1ddc6c Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 21 May 2024 16:08:49 +0200 Subject: [PATCH 1/6] ENH add the parameter prefit in the FixedThresholdClassifier --- doc/modules/classification_threshold.rst | 3 +- doc/whats_new/v1.6.rst | 8 ++++ .../_classification_threshold.py | 18 ++++++- .../tests/test_classification_threshold.py | 47 ++++++++++++++++++- 4 files changed, 73 insertions(+), 3 deletions(-) diff --git a/doc/modules/classification_threshold.rst b/doc/modules/classification_threshold.rst index 712a094a43246..236c0736f7d23 100644 --- a/doc/modules/classification_threshold.rst +++ b/doc/modules/classification_threshold.rst @@ -143,7 +143,8 @@ Manually setting the decision threshold The previous sections discussed strategies to find an optimal decision threshold. It is also possible to manually set the decision threshold using the class -:class:`~sklearn.model_selection.FixedThresholdClassifier`. +:class:`~sklearn.model_selection.FixedThresholdClassifier`. In case that you don't want +to refit the model when calling `fit`, you can set the parameter `prefit=True`. Examples -------- diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index 0e6844155c6fa..252d44ee0a302 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -74,6 +74,14 @@ Changelog :pr:`123456` by :user:`Joe Bloggs `. where 123455 is the *pull request* number, not the issue number. +:mod:`sklearn.model_selection` +.............................. + +- |Enhancement| add the parameter `prefit` to + :class:`model_selection.FixedThresholdClassifier` allowing to use a pre-fitted + estimator without re-fitting it. + :pr:`xxx` by :user:`Guillaume Lemaitre `. + Thanks to everyone who has contributed to the maintenance and improvement of the project since version 1.5, including: diff --git a/sklearn/model_selection/_classification_threshold.py b/sklearn/model_selection/_classification_threshold.py index 1f891577b4680..8d67e87f8304c 100644 --- a/sklearn/model_selection/_classification_threshold.py +++ b/sklearn/model_selection/_classification_threshold.py @@ -271,6 +271,13 @@ class FixedThresholdClassifier(BaseThresholdClassifier): If the method is not implemented by the classifier, it will raise an error. + prefit : bool, default=False + Whether a prefit model is expected to be passed into the constructor directly or + not. If `True`, `estimator` must be a fitted estimator. If `False`, `estimator` + is fitted and updated by calling `fit`, respectively. + + .. versionadded:: 1.6 + Attributes ---------- estimator_ : estimator instance @@ -322,6 +329,7 @@ class FixedThresholdClassifier(BaseThresholdClassifier): **BaseThresholdClassifier._parameter_constraints, "threshold": [StrOptions({"auto"}), Real], "pos_label": [Real, str, "boolean", None], + "refit": ["boolean"], } def __init__( @@ -331,10 +339,12 @@ def __init__( threshold="auto", pos_label=None, response_method="auto", + prefit=False, ): super().__init__(estimator=estimator, response_method=response_method) self.pos_label = pos_label self.threshold = threshold + self.prefit = prefit def _fit(self, X, y, **params): """Fit the classifier. @@ -357,7 +367,13 @@ def _fit(self, X, y, **params): Returns an instance of self. """ routed_params = process_routing(self, "fit", **params) - self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit) + if self.prefit: + check_is_fitted(self.estimator) + self.estimator_ = self.estimator + else: + self.estimator_ = clone(self.estimator).fit( + X, y, **routed_params.estimator.fit + ) return self def predict(self, X): diff --git a/sklearn/model_selection/tests/test_classification_threshold.py b/sklearn/model_selection/tests/test_classification_threshold.py index f64edb2563c76..54badda301780 100644 --- a/sklearn/model_selection/tests/test_classification_threshold.py +++ b/sklearn/model_selection/tests/test_classification_threshold.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from sklearn.base import clone +from sklearn.base import BaseEstimator, ClassifierMixin, clone from sklearn.datasets import ( load_breast_cancer, load_iris, @@ -682,3 +682,48 @@ def test_fixed_threshold_classifier_metadata_routing(): classifier_default_threshold = FixedThresholdClassifier(estimator=clone(classifier)) classifier_default_threshold.fit(X, y, sample_weight=sample_weight) assert_allclose(classifier_default_threshold.estimator_.coef_, classifier.coef_) + + +class ClassifierLoggingFit(ClassifierMixin, BaseEstimator): + """Classifier that logs the number of `fit` calls.""" + + def __init__(self, fit_calls=0): + self.fit_calls = fit_calls + + def fit(self, X, y, **fit_params): + self.fit_calls += 1 + self.is_fitted_ = True + return self + + def predict_proba(self, X): + return np.ones((X.shape[0], 2), np.float64) + + +def test_fixed_threshold_classifier_prefit(): + """Check the behaviour of the `FixedThresholdClassifier` with the `prefit` + parameter.""" + X, y = make_classification(random_state=0) + + estimator = ClassifierLoggingFit() + model = FixedThresholdClassifier(estimator=estimator, prefit=True) + with pytest.raises(NotFittedError): + model.fit(X, y) + + # check that we don't clone the classifier when `prefit=True`. + estimator.fit(X, y) + model.fit(X, y) + assert estimator.fit_calls == 1 + assert model.estimator_ is estimator + + # check that we don't call `fit` of the underlying estimator as well. + model.fit(X, y) + assert estimator.fit_calls == 1 + assert model.estimator_ is estimator + + # check that we clone the classifier when `prefit=False`. + estimator = ClassifierLoggingFit() + model = FixedThresholdClassifier(estimator=estimator, prefit=False) + model.fit(X, y) + assert estimator.fit_calls == 0 + assert model.estimator_.fit_calls == 1 + assert model.estimator_ is not estimator From cb663b736f140af6455d5548b2298bb994cdc8dd Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 21 May 2024 16:10:29 +0200 Subject: [PATCH 2/6] DOC update the pr number --- doc/whats_new/v1.6.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index 252d44ee0a302..d6f0e35a903a3 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -80,7 +80,7 @@ Changelog - |Enhancement| add the parameter `prefit` to :class:`model_selection.FixedThresholdClassifier` allowing to use a pre-fitted estimator without re-fitting it. - :pr:`xxx` by :user:`Guillaume Lemaitre `. + :pr:`29067` by :user:`Guillaume Lemaitre `. Thanks to everyone who has contributed to the maintenance and improvement of the project since version 1.5, including: From 1a5231b4590a299e6e55b4d64776729ea60feba3 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 21 May 2024 16:24:42 +0200 Subject: [PATCH 3/6] FIX use prefit instead of refit in parameter validation --- sklearn/model_selection/_classification_threshold.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/model_selection/_classification_threshold.py b/sklearn/model_selection/_classification_threshold.py index 8d67e87f8304c..5a590a86cbe8d 100644 --- a/sklearn/model_selection/_classification_threshold.py +++ b/sklearn/model_selection/_classification_threshold.py @@ -329,7 +329,7 @@ class FixedThresholdClassifier(BaseThresholdClassifier): **BaseThresholdClassifier._parameter_constraints, "threshold": [StrOptions({"auto"}), Real], "pos_label": [Real, str, "boolean", None], - "refit": ["boolean"], + "prefit": ["boolean"], } def __init__( From b5793d30873fa21dd5bf99d0725d9fe1cca8a2f6 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 24 May 2024 10:05:31 +0200 Subject: [PATCH 4/6] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jérémie du Boisberranger --- sklearn/model_selection/_classification_threshold.py | 4 ++-- .../model_selection/tests/test_classification_threshold.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/model_selection/_classification_threshold.py b/sklearn/model_selection/_classification_threshold.py index 5a590a86cbe8d..c9615a262fbdb 100644 --- a/sklearn/model_selection/_classification_threshold.py +++ b/sklearn/model_selection/_classification_threshold.py @@ -272,9 +272,9 @@ class FixedThresholdClassifier(BaseThresholdClassifier): error. prefit : bool, default=False - Whether a prefit model is expected to be passed into the constructor directly or + Whether a pre-fitted model is expected to be passed into the constructor directly or not. If `True`, `estimator` must be a fitted estimator. If `False`, `estimator` - is fitted and updated by calling `fit`, respectively. + is fitted and updated by calling `fit`. .. versionadded:: 1.6 diff --git a/sklearn/model_selection/tests/test_classification_threshold.py b/sklearn/model_selection/tests/test_classification_threshold.py index 54badda301780..9c91aeeb03512 100644 --- a/sklearn/model_selection/tests/test_classification_threshold.py +++ b/sklearn/model_selection/tests/test_classification_threshold.py @@ -696,7 +696,7 @@ def fit(self, X, y, **fit_params): return self def predict_proba(self, X): - return np.ones((X.shape[0], 2), np.float64) + return np.ones((X.shape[0], 2), np.float64) # pragma: nocover def test_fixed_threshold_classifier_prefit(): From 2a600eb272bca416f5e0df56a2d0a79ebd6e7f38 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 24 May 2024 10:09:46 +0200 Subject: [PATCH 5/6] fix --- sklearn/model_selection/_classification_threshold.py | 6 +++--- .../model_selection/tests/test_classification_threshold.py | 5 ----- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/sklearn/model_selection/_classification_threshold.py b/sklearn/model_selection/_classification_threshold.py index c9615a262fbdb..86d7efb74c826 100644 --- a/sklearn/model_selection/_classification_threshold.py +++ b/sklearn/model_selection/_classification_threshold.py @@ -272,9 +272,9 @@ class FixedThresholdClassifier(BaseThresholdClassifier): error. prefit : bool, default=False - Whether a pre-fitted model is expected to be passed into the constructor directly or - not. If `True`, `estimator` must be a fitted estimator. If `False`, `estimator` - is fitted and updated by calling `fit`. + Whether a pre-fitted model is expected to be passed into the constructor + directly or not. If `True`, `estimator` must be a fitted estimator. If `False`, + `estimator` is fitted and updated by calling `fit`. .. versionadded:: 1.6 diff --git a/sklearn/model_selection/tests/test_classification_threshold.py b/sklearn/model_selection/tests/test_classification_threshold.py index 9c91aeeb03512..77c4c20e99ef2 100644 --- a/sklearn/model_selection/tests/test_classification_threshold.py +++ b/sklearn/model_selection/tests/test_classification_threshold.py @@ -715,11 +715,6 @@ def test_fixed_threshold_classifier_prefit(): assert estimator.fit_calls == 1 assert model.estimator_ is estimator - # check that we don't call `fit` of the underlying estimator as well. - model.fit(X, y) - assert estimator.fit_calls == 1 - assert model.estimator_ is estimator - # check that we clone the classifier when `prefit=False`. estimator = ClassifierLoggingFit() model = FixedThresholdClassifier(estimator=estimator, prefit=False) From a99517f3508cc0c5304455ff95c5ffcabc33f140 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Thu, 30 May 2024 12:13:35 +0500 Subject: [PATCH 6/6] Update doc/whats_new/v1.6.rst --- doc/whats_new/v1.6.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index af22a5b5dbb39..dabe83fb53106 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -91,8 +91,8 @@ Changelog :mod:`sklearn.model_selection` .............................. -- |Enhancement| add the parameter `prefit` to - :class:`model_selection.FixedThresholdClassifier` allowing to use a pre-fitted +- |Enhancement| Add the parameter `prefit` to + :class:`model_selection.FixedThresholdClassifier` allowing the use of a pre-fitted estimator without re-fitting it. :pr:`29067` by :user:`Guillaume Lemaitre `.