TST check equivalence sample_weight in CalibratedClassifierCV by glemaitre · Pull Request #21179 · scikit-learn/scikit-learn
TST check equivalence sample_weight in CalibratedClassifierCV #21179

Merged · 10 commits · Oct 1, 2021
23 changes: 23 additions & 0 deletions doc/whats_new/v1.0.rst
@@ -2,6 +2,29 @@

.. currentmodule:: sklearn

.. _changes_1_0_1:

Version 1.0.1
=============

**In Development**

Changelog
---------

:mod:`sklearn.calibration`
..........................

- |Fix| Fixed :class:`calibration.CalibratedClassifierCV` to take into account
  `sample_weight` when computing the base estimator prediction when
  `ensemble=False`.
  :pr:`20638` by :user:`Julien Bohné <JulienB-78>`.

- |Fix| Fixed a bug in :class:`calibration.CalibratedClassifierCV` with
  `method="sigmoid"` that was ignoring the `sample_weight` when computing the
  Bayesian priors.
  :pr:`21179` by :user:`Guillaume Lemaitre <glemaitre>`.

.. _changes_1_0:

Version 1.0.0
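For context, a minimal sketch of the user-facing behaviour these changelog entries describe; the dataset, parameter values, and variable names are illustrative and not taken from the PR:

import numpy as np
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)
X, y = X[:100], y[:100]  # keep only two classes for a binary problem
sample_weight = np.linspace(0.1, 2.0, num=y.shape[0])  # arbitrary weights

# With ensemble=False, the cross-validated predictions of the base estimator
# are now computed with the provided sample_weight, and method="sigmoid"
# uses the weighted counts of positive/negative samples for its priors.
calibrated = CalibratedClassifierCV(
    LogisticRegression(max_iter=1000), method="sigmoid", ensemble=False, cv=3
)
calibrated.fit(X, y, sample_weight=sample_weight)
proba = calibrated.predict_proba(X)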
47 changes: 33 additions & 14 deletions sklearn/calibration.py
@@ -267,6 +267,8 @@ def fit(self, X, y, sample_weight=None):
        """
        check_classification_targets(y)
        X, y = indexable(X, y)
        if sample_weight is not None:
            sample_weight = _check_sample_weight(sample_weight, X)

        if self.base_estimator is None:
            # we want all classifiers that don't expose a random_state
@@ -303,15 +305,17 @@ def fit(self, X, y, sample_weight=None):
        # sample_weight checks
        fit_parameters = signature(base_estimator.fit).parameters
        supports_sw = "sample_weight" in fit_parameters
        if sample_weight is not None:
            sample_weight = _check_sample_weight(sample_weight, X)
            if not supports_sw:
                estimator_name = type(base_estimator).__name__
                warnings.warn(
                    f"Since {estimator_name} does not support "
                    "sample_weights, sample weights will only be"
                    " used for the calibration itself."
                )
        if sample_weight is not None and not supports_sw:
            estimator_name = type(base_estimator).__name__
            warnings.warn(
                f"Since {estimator_name} does not appear to accept sample_weight, "
                "sample weights will only be used for the calibration itself. This "
                "can be caused by a limitation of the current scikit-learn API. "
                "See the following issue for more details: "
                "https://github.com/scikit-learn/scikit-learn/issues/21134. Be "
                "warned that the result of the calibration is likely to be "
                "incorrect."
            )

        # Check that each cross-validation fold can have at least one
        # example per class
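As an aside, a minimal sketch of when this warning fires, using KNeighborsClassifier only because its fit method takes no sample_weight; the data and variable names are illustrative:

import numpy as np
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)

# KNeighborsClassifier.fit has no sample_weight parameter, so the weights
# can only be honoured by the calibration step itself and the UserWarning
# above is raised during fit.
calibrated = CalibratedClassifierCV(KNeighborsClassifier(), cv=3)
calibrated.fit(X, y, sample_weight=np.ones(y.shape[0]))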
@@ -351,6 +355,11 @@
        else:
            this_estimator = clone(base_estimator)
            _, method_name = _get_prediction_method(this_estimator)
            fit_params = (
                {"sample_weight": sample_weight}
                if sample_weight is not None and supports_sw
                else None
            )
            pred_method = partial(
                cross_val_predict,
                estimator=this_estimator,
@@ -359,6 +368,7 @@
                cv=cv,
                method=method_name,
                n_jobs=self.n_jobs,
                fit_params=fit_params,
            )
            predictions = _compute_predictions(
                pred_method, method_name, X, n_classes
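For readers unfamiliar with the fit_params argument of cross_val_predict, a standalone sketch of the forwarding mechanism used above; the data and estimator are illustrative and not part of the PR:

import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict

X, y = load_iris(return_X_y=True)
sample_weight = np.ones(y.shape[0])

# Every internal call to LogisticRegression.fit receives the weights, which
# is exactly what the ensemble=False code path above now relies on.
proba = cross_val_predict(
    LogisticRegression(max_iter=1000),
    X,
    y,
    cv=5,
    method="predict_proba",
    fit_params={"sample_weight": sample_weight},
)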
@@ -436,7 +446,9 @@ def _more_tags(self):
        return {
            "_xfail_checks": {
                "check_sample_weights_invariance": (
                    "zero sample_weight is not equivalent to removing samples"
                    "Due to the cross-validation and sample ordering, removing a sample"
                    " is not strictly equal to setting its weight to zero. Specific unit"
                    " tests are added for CalibratedClassifierCV."
                ),
            }
        }

Review comment (member): "Very informative, thanks!"
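The new xfail message refers to the fact that dropping a sample changes how the cross-validation splitter partitions the remaining data; a minimal sketch of that effect (indices chosen arbitrarily):

import numpy as np
from sklearn.model_selection import KFold

X = np.arange(10).reshape(-1, 1)
cv = KFold(n_splits=2)

# With all 10 samples present, the test folds are [0..4] and [5..9], even if
# one of the samples carries a weight of zero.
print([test for _, test in cv.split(X)])

# After actually removing sample 3, only 9 samples remain and the test folds
# become [0..4] and [5..8] of the reduced array, so the two settings train
# on different subsets and are not strictly equivalent.
print([test for _, test in cv.split(np.delete(X, 3, axis=0))])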
@@ -760,10 +772,17 @@ def _sigmoid_calibration(predictions, y, sample_weight=None):

    F = predictions  # F follows Platt's notations

    # Bayesian priors (see Platt end of section 2.2)
    prior0 = float(np.sum(y <= 0))
    prior1 = y.shape[0] - prior0
    T = np.zeros(y.shape)
    # Bayesian priors (see Platt end of section 2.2):
    # These correspond to the numbers of samples, taking the `sample_weight`
    # into account.
    mask_negative_samples = y <= 0
    if sample_weight is not None:
        prior0 = (sample_weight[mask_negative_samples]).sum()
        prior1 = (sample_weight[~mask_negative_samples]).sum()
    else:
        prior0 = float(np.sum(mask_negative_samples))
        prior1 = y.shape[0] - prior0
    T = np.zeros_like(y, dtype=np.float64)
    T[y > 0] = (prior1 + 1.0) / (prior1 + 2.0)
    T[y <= 0] = 1.0 / (prior0 + 2.0)
    T1 = 1.0 - T
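A short worked example of the weighted priors above, with made-up labels and weights, to show how the Platt targets change:

import numpy as np

y = np.array([0, 0, 1, 1, 1])
sample_weight = np.array([1.0, 3.0, 2.0, 2.0, 2.0])

mask_negative_samples = y <= 0
# Weighted counts replace the plain counts of the previous implementation.
prior0 = sample_weight[mask_negative_samples].sum()   # 4.0 instead of 2
prior1 = sample_weight[~mask_negative_samples].sum()  # 6.0 instead of 3

# Platt's targets (end of section 2.2 of his paper)
T = np.zeros_like(y, dtype=np.float64)
T[y > 0] = (prior1 + 1.0) / (prior1 + 2.0)  # 7/8 instead of 4/5
T[y <= 0] = 1.0 / (prior0 + 2.0)            # 1/6 instead of 1/4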
96 changes: 95 additions & 1 deletion sklearn/tests/test_calibration.py
@@ -6,7 +6,7 @@
from numpy.testing import assert_allclose
from scipy import sparse

from sklearn.base import BaseEstimator
from sklearn.base import BaseEstimator, clone
from sklearn.dummy import DummyClassifier
from sklearn.model_selection import LeaveOneOut, train_test_split

@@ -784,3 +784,97 @@ def test_calibration_display_ref_line(pyplot, iris_data_binary):

    labels = viz2.ax_.get_legend_handles_labels()[1]
    assert labels.count("Perfectly calibrated") == 1


@pytest.mark.parametrize("method", ["sigmoid", "isotonic"])
@pytest.mark.parametrize("ensemble", [True, False])
def test_calibrated_classifier_cv_double_sample_weights_equivalence(method, ensemble):
"""Check that passing repeating twice the dataset `X` is equivalent to
passing a `sample_weight` with a factor 2."""
X, y = load_iris(return_X_y=True)
# Scale the data to avoid any convergence issue
X = StandardScaler().fit_transform(X)
# Only use 2 classes
X, y = X[:100], y[:100]
sample_weight = np.ones_like(y) * 2

# Interlace the data such that a 2-fold cross-validation will be equivalent
# to using the original dataset with a sample weights of 2
X_twice = np.zeros((X.shape[0] * 2, X.shape[1]), dtype=X.dtype)
X_twice[::2, :] = X
X_twice[1::2, :] = X
y_twice = np.zeros(y.shape[0] * 2, dtype=y.dtype)
y_twice[::2] = y
y_twice[1::2] = y

base_estimator = LogisticRegression()
calibrated_clf_without_weights = CalibratedClassifierCV(
base_estimator,
method=method,
ensemble=ensemble,
cv=2,
)
calibrated_clf_with_weights = clone(calibrated_clf_without_weights)

calibrated_clf_with_weights.fit(X, y, sample_weight=sample_weight)
calibrated_clf_without_weights.fit(X_twice, y_twice)

# Check that the underlying fitted estimators have the same coefficients
for est_with_weights, est_without_weights in zip(
calibrated_clf_with_weights.calibrated_classifiers_,
calibrated_clf_without_weights.calibrated_classifiers_,
):
assert_allclose(
est_with_weights.base_estimator.coef_,
est_without_weights.base_estimator.coef_,
)

# Check that the predictions are the same
y_pred_with_weights = calibrated_clf_with_weights.predict_proba(X)
y_pred_without_weights = calibrated_clf_without_weights.predict_proba(X)

assert_allclose(y_pred_with_weights, y_pred_without_weights)


@pytest.mark.parametrize("method", ["sigmoid", "isotonic"])
@pytest.mark.parametrize("ensemble", [True, False])
def test_calibrated_classifier_cv_zeros_sample_weights_equivalence(method, ensemble):
"""Check that passing removing some sample from the dataset `X` is
equivalent to passing a `sample_weight` with a factor 0."""
X, y = load_iris(return_X_y=True)
# Scale the data to avoid any convergence issue
X = StandardScaler().fit_transform(X)
# Only use 2 classes and select samples such that 2-fold cross-validation
# split will lead to an equivalence with a `sample_weight` of 0
X = np.vstack((X[:40], X[50:90]))
y = np.hstack((y[:40], y[50:90]))
sample_weight = np.zeros_like(y)
sample_weight[::2] = 1

base_estimator = LogisticRegression()
calibrated_clf_without_weights = CalibratedClassifierCV(
base_estimator,
method=method,
ensemble=ensemble,
cv=2,
)
calibrated_clf_with_weights = clone(calibrated_clf_without_weights)

calibrated_clf_with_weights.fit(X, y, sample_weight=sample_weight)
calibrated_clf_without_weights.fit(X[::2], y[::2])

# Check that the underlying fitted estimators have the same coefficients
for est_with_weights, est_without_weights in zip(
calibrated_clf_with_weights.calibrated_classifiers_,
calibrated_clf_without_weights.calibrated_classifiers_,
):
assert_allclose(
est_with_weights.base_estimator.coef_,
est_without_weights.base_estimator.coef_,
)

# Check that the predictions are the same
y_pred_with_weights = calibrated_clf_with_weights.predict_proba(X)
y_pred_without_weights = calibrated_clf_without_weights.predict_proba(X)

assert_allclose(y_pred_with_weights, y_pred_without_weights)
Review comment (member): "Nice and much needed new test!"
