FIX `CalibratedClassifierCV` should not ignore `sample_weight` if estimator does not support it by glemaitre · Pull Request #21143 · scikit-learn/scikit-learn
FIX CalibratedClassifierCV should not ignore sample_weight if estimator does not support it #21143

Status: Closed · wants to merge 6 commits
doc/whats_new/v1.0.rst (18 additions, 0 deletions)
@@ -2,6 +2,24 @@

 .. currentmodule:: sklearn

+.. _changes_1_0_1:
+
+Version 1.0.1
+=============
+
+**In Development**
+
+Changelog
+---------
+
+:mod:`sklearn.calibration`
+..........................
+
+- |Fix| Raise an error instead of silently ignoring `sample_weight` in
+  :class:`calibration.CalibratedClassifierCV` when the underlying estimator
+  does not support it.
+  :pr:`21143` by :user:`Guillaume Lemaitre <glemaitre>`.
+
 .. _changes_1_0:

 Version 1.0.0
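For context, a minimal sketch of the user-visible change described by this entry, assuming the patch below is applied. `KNeighborsClassifier` stands in for any estimator whose `fit` does not accept `sample_weight`; previously the weights were silently used only for the calibrator (with a warning), now the same call raises:

import numpy as np
from sklearn.calibration import CalibratedClassifierCV
from sklearn.neighbors import KNeighborsClassifier

rng = np.random.RandomState(0)
X = rng.randn(100, 4)
y = rng.randint(2, size=100)

# KNeighborsClassifier.fit(X, y) has no `sample_weight` parameter, so the
# weights cannot be honored by the cross-validated estimator fits.
calibrated = CalibratedClassifierCV(KNeighborsClassifier())

# With this patch: raises ValueError instead of warning and dropping weights.
calibrated.fit(X, y, sample_weight=np.ones(100))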
sklearn/calibration.py (12 additions, 18 deletions)
@@ -7,7 +7,6 @@
 #
 # License: BSD 3 clause

-import warnings
 from inspect import signature
 from functools import partial

@@ -37,7 +36,10 @@

 from .utils.multiclass import check_classification_targets
 from .utils.fixes import delayed
-from .utils.validation import check_is_fitted, check_consistent_length
+from .utils.validation import (
+    check_is_fitted,
+    check_consistent_length,
+)
 from .utils.validation import _check_sample_weight, _num_samples
 from .utils import _safe_indexing
 from .isotonic import IsotonicRegression
@@ -302,16 +304,12 @@ def fit(self, X, y, sample_weight=None):

         # sample_weight checks
         fit_parameters = signature(base_estimator.fit).parameters
-        supports_sw = "sample_weight" in fit_parameters
+        if "sample_weight" not in fit_parameters and sample_weight is not None:
+            raise ValueError(
+                f"The estimator {base_estimator} does not support sample_weight"
+            )
         if sample_weight is not None:
             sample_weight = _check_sample_weight(sample_weight, X)
-            if not supports_sw:
-                estimator_name = type(base_estimator).__name__
-                warnings.warn(
-                    f"Since {estimator_name} does not support "
-                    "sample_weights, sample weights will only be"
-                    " used for the calibration itself."
-                )

         # Check that each cross-validation fold can have at least one
         # example per class
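The new check keeps the existing `inspect.signature` introspection and turns the old warning into an error. A standalone sketch of the idiom, with a hypothetical `NoWeightClassifier` standing in for any estimator lacking weight support; note that `Pipeline.fit(self, X, y=None, **fit_params)` also exposes no explicit `sample_weight` parameter, which is why pipelines now hit this error path (see the test added below):

from inspect import signature

class NoWeightClassifier:  # hypothetical: `fit` takes no `sample_weight`
    def fit(self, X, y):
        return self

fit_parameters = signature(NoWeightClassifier().fit).parameters
# This is the condition that now triggers the ValueError when weights are given.
assert "sample_weight" not in fit_parameters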
@@ -343,7 +341,6 @@ def fit(self, X, y, sample_weight=None):
                         test=test,
                         method=self.method,
                         classes=self.classes_,
-                        supports_sw=supports_sw,
                         sample_weight=sample_weight,
                     )
                     for train, test in cv.split(X, y)
@@ -364,7 +361,7 @@ def fit(self, X, y, sample_weight=None):
                     pred_method, method_name, X, n_classes
                 )

-                if sample_weight is not None and supports_sw:
+                if sample_weight is not None:
                     this_estimator.fit(X, y, sample_weight)
                 else:
                     this_estimator.fit(X, y)
@@ -443,7 +440,7 @@ def _more_tags(self):


 def _fit_classifier_calibrator_pair(
-    estimator, X, y, train, test, supports_sw, method, classes, sample_weight=None
+    estimator, X, y, train, test, method, classes, sample_weight=None
 ):
     """Fit a classifier/calibration pair on a given train/test split.

@@ -468,9 +465,6 @@ def _fit_classifier_calibrator_pair(
     test : ndarray, shape (n_test_indicies,)
         Indices of the testing subset.

-    supports_sw : bool
-        Whether or not the `estimator` supports sample weights.
-
     method : {'sigmoid', 'isotonic'}
         Method to use for calibration.

@@ -486,14 +480,14 @@ def _fit_classifier_calibrator_pair(
     """
     X_train, y_train = _safe_indexing(X, train), _safe_indexing(y, train)
     X_test, y_test = _safe_indexing(X, test), _safe_indexing(y, test)
-    if supports_sw and sample_weight is not None:
+    if sample_weight is not None:
         sw_train = _safe_indexing(sample_weight, train)
         sw_test = _safe_indexing(sample_weight, test)
     else:
         sw_train = None
         sw_test = None

-    if supports_sw:
+    if sample_weight is not None:
         estimator.fit(X_train, y_train, sample_weight=sw_train)
     else:
         estimator.fit(X_train, y_train)
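With `supports_sw` gone, the slicing above keys only on whether weights were passed at all. A minimal sketch of that per-fold slicing, assuming only that the private helper `sklearn.utils._safe_indexing` keeps its current behavior (positional indexing that works uniformly for arrays, lists and DataFrames):

import numpy as np
from sklearn.utils import _safe_indexing

sample_weight = np.array([1.0, 2.0, 3.0, 4.0])
train, test = np.array([0, 2]), np.array([1, 3])

# Same slicing as in _fit_classifier_calibrator_pair:
sw_train = _safe_indexing(sample_weight, train)  # array([1., 3.])
sw_test = _safe_indexing(sample_weight, test)    # array([2., 4.])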
sklearn/tests/test_metaestimators.py (65 additions, 10 deletions)
@@ -5,7 +5,7 @@
 import numpy as np
 import pytest

-from sklearn.base import BaseEstimator
+from sklearn.base import clone, BaseEstimator
 from sklearn.base import is_regressor
 from sklearn.datasets import make_classification
 from sklearn.utils import all_estimators
@@ -21,7 +21,7 @@
 from sklearn.exceptions import NotFittedError
 from sklearn.semi_supervised import SelfTrainingClassifier
 from sklearn.linear_model import Ridge, LogisticRegression
-from sklearn.preprocessing import StandardScaler, MaxAbsScaler
+from sklearn.preprocessing import StandardScaler, MaxAbsScaler, FunctionTransformer


 class DelegatorData:
@@ -185,21 +185,35 @@ def score(self, X, y, *args, **kwargs):
 )


-def _generate_meta_estimator_instances_with_pipeline():
-    """Generate instances of meta-estimators fed with a pipeline
+def _generate_meta_estimator_instances_with_pipeline(first_step=None):
+    """Generate instances of meta-estimators fed with a pipeline.

     Are considered meta-estimators all estimators accepting one of "estimator",
     "base_estimator" or "estimators".
+
+    Parameters
+    ----------
+    first_step : None or estimator instance, default=None
+        The first step of the pipeline to pass to the meta-estimator.
+        If `None`, a `TfidfVectorizer` is used.
+
+    Yields
+    ------
+    estimator : estimator instance
+        A meta-estimator instance.
+
     """
+    if first_step is None:
+        first_step = TfidfVectorizer()
     for _, Estimator in sorted(all_estimators()):
         sig = set(signature(Estimator).parameters)

         if "estimator" in sig or "base_estimator" in sig or "regressor" in sig:
             if is_regressor(Estimator):
-                estimator = make_pipeline(TfidfVectorizer(), Ridge())
+                estimator = make_pipeline(first_step, Ridge())
                 param_grid = {"ridge__alpha": [0.1, 1.0]}
             else:
-                estimator = make_pipeline(TfidfVectorizer(), LogisticRegression())
+                estimator = make_pipeline(first_step, LogisticRegression())
                 param_grid = {"logisticregression__C": [0.1, 1.0]}

         if "param_grid" in sig or "param_distributions" in sig:
@@ -212,10 +226,10 @@ def _generate_meta_estimator_instances_with_pipeline():
         elif "transformer_list" in sig:
             # FeatureUnion
             transformer_list = [
-                ("trans1", make_pipeline(TfidfVectorizer(), MaxAbsScaler())),
+                ("trans1", make_pipeline(first_step, MaxAbsScaler())),
                 (
                     "trans2",
-                    make_pipeline(TfidfVectorizer(), StandardScaler(with_mean=False)),
+                    make_pipeline(first_step, StandardScaler(with_mean=False)),
                 ),
             ]
             yield Estimator(transformer_list)
@@ -224,8 +238,8 @@ def _generate_meta_estimator_instances_with_pipeline():
             # stacking, voting
             if is_regressor(Estimator):
                 estimator = [
-                    ("est1", make_pipeline(TfidfVectorizer(), Ridge(alpha=0.1))),
-                    ("est2", make_pipeline(TfidfVectorizer(), Ridge(alpha=1))),
+                    ("est1", make_pipeline(first_step, Ridge(alpha=0.1))),
+                    ("est2", make_pipeline(first_step, Ridge(alpha=1))),
                 ]
             else:
                 estimator = [
@@ -299,3 +313,44 @@ def test_meta_estimators_delegate_data_validation(estimator):

     # n_features_in_ should not be defined since data is not tabular data.
     assert not hasattr(estimator, "n_features_in_")
+
+
+METAESTIMATORS_ACCEPTING_SAMPLE_WEIGHTS_IN_PIPELINE = [
+    # AdaBoostRegressor can safely accept `sample_weight` even with a
+    # `Pipeline` because the weights are not used when calling `pipeline.fit`.
+    "AdaBoostRegressor",
+]
+
+
+@pytest.mark.parametrize(
+    "estimator_orig",
+    [
+        est
+        for est in _generate_meta_estimator_instances_with_pipeline(
+            first_step=FunctionTransformer()
+        )
+        if est.__class__.__name__
+        not in METAESTIMATORS_ACCEPTING_SAMPLE_WEIGHTS_IN_PIPELINE
+    ],
+    ids=_get_meta_estimator_id,
+)
+def test_meta_estimators_sample_weight_with_pipeline(estimator_orig):
+    """Check that passing a `Pipeline` with `sample_weight` raises an error.
+
+    FIXME: in the future, `Pipeline` should be able to delegate `sample_weight`
+    to the inner estimator(s). An error should not be raised then.
+    """
+    estimator = clone(estimator_orig)
+    rng = np.random.RandomState(0)
+    set_random_state(estimator)
+
+    n_samples = 100
+    X = rng.randn(n_samples, 10)
+    y = rng.randint(3, size=n_samples)
+    y = _enforce_estimator_tags_y(estimator, y)
+
+    sample_weight = np.ones(y.shape[0])
+    sample_weight[0] = 10.0
+
+    with pytest.raises((ValueError, TypeError), match="sample.*weight"):
+        estimator.fit(X, y, sample_weight=sample_weight)
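Depending on the meta-estimator, the expected error comes either from the new signature check (as in `CalibratedClassifierCV` above) or from the pipeline layer itself: `Pipeline.fit` only routes fit parameters addressed as `<step>__<param>`, so a bare `sample_weight` keyword is rejected. A sketch of the failing and the working call, assuming current `Pipeline` semantics:

import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer
from sklearn.linear_model import LogisticRegression

rng = np.random.RandomState(0)
X, y = rng.randn(20, 3), rng.randint(2, size=20)
pipe = make_pipeline(FunctionTransformer(), LogisticRegression())

# Rejected: Pipeline.fit does not accept an un-prefixed `sample_weight`.
# pipe.fit(X, y, sample_weight=np.ones(20))

# Accepted: route the weights explicitly to the final step.
pipe.fit(X, y, logisticregression__sample_weight=np.ones(20))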