8000 ENH add the parameter prefit in the FixedThresholdClassifier by glemaitre · Pull Request #29067 · scikit-learn/scikit-learn · GitHub
ENH add the parameter prefit in the FixedThresholdClassifier #29067

Merged: 8 commits, May 30, 2024
Changes from all commits
3 changes: 2 additions & 1 deletion doc/modules/classification_threshold.rst
@@ -143,7 +143,8 @@ Manually setting the decision threshold

The previous sections discussed strategies to find an optimal decision threshold. It is
also possible to manually set the decision threshold using the class
:class:`~sklearn.model_selection.FixedThresholdClassifier`.
:class:`~sklearn.model_selection.FixedThresholdClassifier`. If you do not want to
refit the model when calling `fit`, you can set the parameter `prefit=True`.

Examples
--------
8 changes: 8 additions & 0 deletions doc/whats_new/v1.6.rst
@@ -88,6 +88,14 @@ Changelog
whether to raise an exception if a subset of the scorers in multimetric scoring fails
or to return an error code. :pr:`28992` by :user:`Stefanie Senger <StefanieSenger>`.

:mod:`sklearn.model_selection`
..............................

- |Enhancement| Add the parameter `prefit` to
:class:`model_selection.FixedThresholdClassifier` allowing the use of a pre-fitted
estimator without re-fitting it.
:pr:`29067` by :user:`Guillaume Lemaitre <glemaitre>`.

Thanks to everyone who has contributed to the maintenance and improvement of
the project since version 1.5, including:

18 changes: 17 additions & 1 deletion sklearn/model_selection/_classification_threshold.py
@@ -271,6 +271,13 @@ class FixedThresholdClassifier(BaseThresholdClassifier):
If the method is not implemented by the classifier, it will raise an
error.

prefit : bool, default=False
Whether a pre-fitted model is expected to be passed into the constructor
directly or not. If `True`, `estimator` must be a fitted estimator. If `False`,
`estimator` is fitted and updated by calling `fit`.

.. versionadded:: 1.6

Attributes
----------
estimator_ : estimator instance
@@ -322,6 +329,7 @@ class FixedThresholdClassifier(BaseThresholdClassifier):
**BaseThresholdClassifier._parameter_constraints,
"threshold": [StrOptions({"auto"}), Real],
"pos_label": [Real, str, "boolean", None],
"prefit": ["boolean"],
}

def __init__(
@@ -331,10 +339,12 @@
threshold="auto",
pos_label=None,
response_method="auto",
prefit=False,
):
super().__init__(estimator=estimator, response_method=response_method)
self.pos_label = pos_label
self.threshold = threshold
self.prefit = prefit

def _fit(self, X, y, **params):
"""Fit the classifier.
@@ -357,7 +367,13 @@
Returns an instance of self.
"""
routed_params = process_routing(self, "fit", **params)
self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)
if self.prefit:
check_is_fitted(self.estimator)
self.estimator_ = self.estimator
else:
self.estimator_ = clone(self.estimator).fit(
X, y, **routed_params.estimator.fit
)
return self

def predict(self, X):
42 changes: 41 additions & 1 deletion sklearn/model_selection/tests/test_classification_threshold.py
@@ -1,7 +1,7 @@
import numpy as np
import pytest

from sklearn.base import clone
from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.datasets import (
load_breast_cancer,
load_iris,
@@ -682,3 +682,43 @@ def test_fixed_threshold_classifier_metadata_routing():
classifier_default_threshold = FixedThresholdClassifier(estimator=clone(classifier))
classifier_default_threshold.fit(X, y, sample_weight=sample_weight)
assert_allclose(classifier_default_threshold.estimator_.coef_, classifier.coef_)


class ClassifierLoggingFit(ClassifierMixin, BaseEstimator):
"""Classifier that logs the number of `fit` calls."""

def __init__(self, fit_calls=0):
self.fit_calls = fit_calls

def fit(self, X, y, **fit_params):
self.fit_calls += 1
self.is_fitted_ = True
return self

def predict_proba(self, X):
return np.ones((X.shape[0], 2), np.float64) # pragma: nocover


def test_fixed_threshold_classifier_prefit():
"""Check the behaviour of the `FixedThresholdClassifier` with the `prefit`
parameter."""
X, y = make_classification(random_state=0)

estimator = ClassifierLoggingFit()
model = FixedThresholdClassifier(estimator=estimator, prefit=True)
with pytest.raises(NotFittedError):
model.fit(X, y)

# check that we don't clone the classifier when `prefit=True`.
estimator.fit(X, y)
model.fit(X, y)
assert estimator.fit_calls == 1
assert model.estimator_ is estimator

# check that we clone the classifier when `prefit=False`.
estimator = ClassifierLoggingFit()
model = FixedThresholdClassifier(estimator=estimator, prefit=False)
model.fit(X, y)
assert estimator.fit_calls == 0
assert model.estimator_.fit_calls == 1
assert model.estimator_ is not estimator