8000 ENH add the parameter prefit in the FixedThresholdClassifier (#29067) · scikit-learn-bot/scikit-learn@e5ed851 · GitHub
[go: up one dir, main page]

Skip to content

Commit e5ed851

Browse files
authored
ENH add the parameter prefit in the FixedThresholdClassifier (scikit-learn#29067)
1 parent d8ee3fc commit e5ed851

File tree

4 files changed

+68
-3
lines changed

4 files changed

+68
-3
lines changed

doc/modules/classification_threshold.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,8 @@ Manually setting the decision threshold
143143

144144
The previous sections discussed strategies to find an optimal decision threshold. It is
145145
also possible to manually set the decision threshold using the class
146-
:class:`~sklearn.model_selection.FixedThresholdClassifier`.
146+
:class:`~sklearn.model_selection.FixedThresholdClassifier`. In case that you don't want
147+
to refit the model when calling `fit`, you can set the parameter `prefit=True`.
147148

148149
Examples
149150
--------

doc/whats_new/v1.6.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,14 @@ Changelog
9595
whether to raise an exception if a subset of the scorers in multimetric scoring fails
9696
or to return an error code. :pr:`28992` by :user:`Stefanie Senger <StefanieSenger>`.
9797

98+
:mod:`sklearn.model_selection`
99+
..............................
100+
101+
- |Enhancement| Add the parameter `prefit` to
102+
:class:`model_selection.FixedThresholdClassifier` allowing the use of a pre-fitted
103+
estimator without re-fitting it.
104+
:pr:`29067` by :user:`Guillaume Lemaitre <glemaitre>`.
105+
98106
Thanks to everyone who has contributed to the maintenance and improvement of
99107
the project since version 1.5, including:
100108

sklearn/model_selection/_classification_threshold.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,13 @@ class FixedThresholdClassifier(BaseThresholdClassifier):
271271
If the method is not implemented by the classifier, it will raise an
272272
error.
273273
274+
prefit : bool, default=False
275+
Whether a pre-fitted model is expected to be passed into the constructor
276+
directly or not. If `True`, `estimator` must be a fitted estimator. If `False`,
277+
`estimator` is fitted and updated by calling `fit`.
278+
279+
.. versionadded:: 1.6
280+
274281
Attributes
275282
----------
276283
estimator_ : estimator instance
@@ -322,6 +329,7 @@ class FixedThresholdClassifier(BaseThresholdClassifier):
322329
**BaseThresholdClassifier._parameter_constraints,
323330
"threshold": [StrOptions({"auto"}), Real],
324331
"pos_label": [Real, str, "boolean", None],
332+
"prefit": ["boolean"],
325333
}
326334

327335
def __init__(
@@ -331,10 +339,12 @@ def __init__(
331339
threshold="auto",
332340
pos_label=None,
333341
response_method="auto",
342+
prefit=False,
334343
):
335344
super().__init__(estimator=estimator, response_method=response_method)
336345
self.pos_label = pos_label
337346
self.threshold = threshold
347+
self.prefit = prefit
338348

339349
def _fit(self, X, y, **params):
340350
"""Fit the classifier.
@@ -357,7 +367,13 @@ def _fit(self, X, y, **params):
357367
Returns an instance of self.
358368
"""
359369
routed_params = process_routing(self, "fit", **params)
360-
self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)
370+
if self.prefit:
371+
check_is_fitted(self.estimator)
372+
self.estimator_ = self.estimator
373+
else:
374+
self.estimator_ = clone(self.estimator).fit(
375+
X, y, **routed_params.estimator.fit
376+
)
361377
return self
362378

363379
def predict(self, X):

sklearn/model_selection/tests/test_classification_threshold.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import pytest
33

4-
from sklearn.base import clone
4+
from sklearn.base import BaseEstimator, ClassifierMixin, clone
55
from sklearn.datasets import (
66
load_breast_cancer,
77
load_iris,
@@ -682,3 +682,43 @@ def test_fixed_threshold_classifier_metadata_routing():
682682
classifier_default_threshold = FixedThresholdClassifier(estimator=clone(classifier))
683683
classifier_default_threshold.fit(X, y, sample_weight=sample_weight)
684684
assert_allclose(classifier_default_threshold.estimator_.coef_, classifier.coef_)
685+
686+
687+
class ClassifierLoggingFit(ClassifierMixin, BaseEstimator):
688+
"""Classifier that logs the number of `fit` calls."""
689+
690+
def __init__(self, fit_calls=0):
691+
self.fit_calls = fit_calls
692+
693+
def fit(self, X, y, **fit_params):
694+
self.fit_calls += 1
695+
self.is_fitted_ = True
696+
return self
697+
698+
def predict_proba(self, X):
699+
return np.ones((X.shape[0], 2), np.float64) # pragma: nocover
700+
701+
702+
def test_fixed_threshold_classifier_prefit():
703+
"""Check the behaviour of the `FixedThresholdClassifier` with the `prefit`
704+
parameter."""
705+
X, y = make_classification(random_state=0)
706+
707+
estimator = ClassifierLoggingFit()
708+
model = FixedThresholdClassifier(estimator=estimator, prefit=True)
709+
with pytest.raises(NotFittedError):
710+
model.fit(X, y)
711+
712+
# check that we don't clone the classifier when `prefit=True`.
713+
estimator.fit(X, y)
714+
model.fit(X, y)
715+
assert estimator.fit_calls == 1
716+
assert model.estimator_ is estimator
717+
718+
# check that we clone the classifier when `prefit=False`.
719+
estimator = ClassifierLoggingFit()
720+
model = FixedThresholdClassifier(estimator=estimator, prefit=False)
721+
model.fit(X, y)
722+
assert estimator.fit_calls == 0
723+
assert model.estimator_.fit_calls == 1
724+
assert model.estimator_ is not estimator

0 commit comments

Comments
 (0)
0