8000 TST check equivalence sample_weight in CalibratedClassifierCV (#21179) · scikit-learn/scikit-learn@17a788d · GitHub
[go: up one dir, main page]

Skip to content

Commit 17a788d

Browse files
glemaitreJulienB-78ogrisel
authored
TST check equivalence sample_weight in CalibratedClassifierCV (#21179)
Co-authored-by: JulienB-78 <jbohne78-github@yahoo.fr> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 39f7790 commit 17a788d

File tree

3 files changed

+151
-15
lines changed

3 files changed

+151
-15
lines changed

doc/whats_new/v1.0.rst

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,29 @@
22

33
.. currentmodule:: sklearn
44

5+
.. _changes_1_0_1:
6+
7+
Version 1.0.1
8+
=============
9+
10+
**In Development**
11+
12+
Changelog
13+
---------
14+
15+
:mod:`sklearn.calibration`
16+
..........................
17+
18+
- |Fix| Fixed :class:`calibration.CalibratedClassifierCV` to take into account
19+
`sample_weight` when computing the base estimator prediction when
20+
`ensemble=False`.
21+
:pr:`20638` by :user:`Julien Bohné <JulienB-78>`.
22+
23+
- |Fix| Fixed a bug in :class:`calibration.CalibratedClassifierCV` with
24+
`method="sigmoid"` that was ignoring the `sample_weight` when computing the
25+
the Bayesian priors.
26+
:pr:`21179` by :user:`Guillaume Lemaitre <glemaitre>`.
27+
528
.. _changes_1_0:
629

730
Version 1.0.0

sklearn/calibration.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,8 @@ def fit(self, X, y, sample_weight=None):
267267
"""
268268
check_classification_targets(y)
269269
X, y = indexable(X, y)
270+
if sample_weight is not None:
271+
sample_weight = _check_sample_weight(sample_weight, X)
270272

271273
if self.base_estimator is None:
272274
# we want all classifiers that don't expose a random_state
@@ -303,15 +305,17 @@ def fit(self, X, y, sample_weight=None):
303305
# sample_weight checks
304306
fit_parameters = signature(base_estimator.fit).parameters
305307
supports_sw = "sample_weight" in fit_parameters
306-
if sample_weight is not None:
307-
sample_weight = _check_sample_weight(sample_weight, X)
308-
if not supports_sw:
309-
estimator_name = type(base_estimator).__name__
310-
warnings.warn(
311-
f"Since {estimator_name} does not support "
312-
"sample_weights, sample weights will only be"
313-
" used for the calibration itself."
314-
)
308+
if sample_weight is not None and not supports_sw:
309+
estimator_name = type(base_estimator).__name__
310+
warnings.warn(
311+
f"Since {estimator_name} does not appear to accept sample_weight, "
312+
"sample weights will only be used for the calibration itself. This "
313+
"can be caused by a limitation of the current scikit-learn API. "
314+
"See the following issue for more details: "
315+
"https://github.com/scikit-learn/scikit-learn/issues/21134. Be "
316+
"warned that the result of the calibration is likely to be "
317+
"incorrect."
318+
)
315319

316320
# Check that each cross-validation fold can have at least one
317321
# example per class
@@ -351,6 +355,11 @@ def fit(self, X, y, sample_weight=None):
351355
else:
352356
this_estimator = clone(base_estimator)
353357
_, method_name = _get_prediction_method(this_estimator)
358+
fit_params = (
359+
{"sample_weight": sample_weight}
360+
if sample_weight is not None and supports_sw
361+
else None
362+
)
354363
pred_method = partial(
355364
cross_val_predict,
356365
estimator=this_estimator,
@@ -359,6 +368,7 @@ def fit(self, X, y, sample_weight=None):
359368
cv=cv,
360369
method=method_name,
361370
n_jobs=self.n_jobs,
371+
fit_params=fit_params,
362372
)
363373
predictions = _compute_predictions(
364374
pred_method, method_name, X, n_classes
@@ -436,7 +446,9 @@ def _more_tags(self):
436446
return {
437447
"_xfail_checks": {
438448
"check_sample_weights_invariance": (
439-
"zero sample_weight is not equivalent to removing samples"
449+
"Due to the cross-validation and sample ordering, removing a sample"
450+
" is not strictly equal to putting is weight to zero. Specific unit"
451+
" tests are added for CalibratedClassifierCV specifically."
440452
),
441453
}
442454
}
@@ -760,10 +772,17 @@ def _sigmoid_calibration(predictions, y, sample_weight=None):
760772

761773
F = predictions # F follows Platt's notations
762774

763-
# Bayesian priors (see Platt end of section 2.2)
764-
prior0 = float(np.sum(y <= 0))
765-
prior1 = y.shape[0] - prior0
766-
T = np.zeros(y.shape)
775+
# Bayesian priors (see Platt end of section 2.2):
776+
# It corresponds to the number of samples, taking into account the
777+
# `sample_weight`.
778+
mask_negative_samples = y <= 0
779+
if sample_weight is not None:
780+
prior0 = (sample_weight[mask_negative_samples]).sum()
781+
prior1 = (sample_weight[~mask_negative_samples]).sum()
782+
else:
783+
prior0 = float(np.sum(mask_negative_samples))
784+
prior1 = y.shape[0] - prior0
785+
T = np.zeros_like(y, dtype=np.float64)
767786
T[y > 0] = (prior1 + 1.0) / (prior1 + 2.0)
768787
T[y <= 0] = 1.0 / (prior0 + 2.0)
769788
T1 = 1.0 - T

sklearn/tests/test_calibration.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from numpy.testing import assert_allclose
77
from scipy import sparse
88

9-
from sklearn.base import BaseEstimator
9+
from sklearn.base import BaseEstimator, clone
1010
from sklearn.dummy import DummyClassifier
1111
from sklearn.model_selection import LeaveOneOut, train_test_split
1212

@@ -784,3 +784,97 @@ def test_calibration_display_ref_line(pyplot, iris_data_binary):
784784

785785
labels = viz2.ax_.get_legend_handles_labels()[1]
786786
assert labels.count("Perfectly calibrated") == 1
787+
788+
789+
@pytest.mark.parametrize("method", ["sigmoid", "isotonic"])
790+
@pytest.mark.parametrize("ensemble", [True, False])
791+
def test_calibrated_classifier_cv_double_sample_weights_equivalence(method, ensemble):
792+
"""Check that passing repeating twice the dataset `X` is equivalent to
793+
passing a `sample_weight` with a factor 2."""
794+
X, y = load_iris(return_X_y=True)
795+
# Scale the data to avoid any convergence issue
796+
X = StandardScaler().fit_transform(X)
797+
# Only use 2 classes
798+
X, y = X[:100], y[:100]
799+
sample_weight = np.ones_like(y) * 2
800+
801+
# Interlace the data such that a 2-fold cross-validation will be equivalent
802+
# to using the original dataset with a sample weights of 2
803+
X_twice = np.zeros((X.shape[0] * 2, X.shape[1]), dtype=X.dtype)
804+
X_twice[::2, :] = X
805+
X_twice[1::2, :] = X
806+
y_twice = np.zeros(y.shape[0] * 2, dtype=y.dtype)
807+
y_twice[::2] = y
808+
y_twice[1::2] = y
809+
810+
base_estimator = LogisticRegression()
811+
calibrated_clf_without_weights = CalibratedClassifierCV(
812+
base_estimator,
813+
method=method,
814+
ensemble=ensemble,
815+
cv=2,
816+
)
817+
calibrated_clf_with_weights = clone(calibrated_clf_without_weights)
818+
819+
calibrated_clf_with_weights.fit(X, y, sample_weight=sample_weight)
820+
calibrated_clf_without_weights.fit(X_twice, y_twice)
821+
822+
# Check that the underlying fitted estimators have the same coefficients
823+
for est_with_weights, est_without_weights in zip(
824+
calibrated_clf_with_weights.calibrated_classifiers_,
825+
calibrated_clf_without_weights.calibrated_classifiers_,
826+
):
827+
assert_allclose(
828+
est_with_weights.base_estimator.coef_,
829+
est_without_weights.base_estimator.coef_,
830+
)
831+
832+
# Check that the predictions are the same
833+
y_pred_with_weights = calibrated_clf_with_weights.predict_proba(X)
834+
y_pred_without_weights = calibrated_clf_without_weights.predict_proba(X)
835+
836+
assert_allclose(y_pred_with_weights, y_pred_without_weights)
837+
838+
839+
@pytest.mark.parametrize("method", ["sigmoid", "isotonic"])
840+
@pytest.mark.parametrize("ensemble", [True, False])
841+
def test_calibrated_classifier_cv_zeros_sample_weights_equivalence(method, ensemble):
842+
"""Check that passing removing some sample from the dataset `X` is
843+
equivalent to passing a `sample_weight` with a factor 0."""
844+
X, y = load_iris(return_X_y=True)
845+
# Scale the data to avoid any convergence issue
846+
X = StandardScaler().fit_transform(X)
847+
# Only use 2 classes and select samples such that 2-fold cross-validation
848+
# split will lead to an equivalence with a `sample_weight` of 0
849+
X = np.vstack((X[:40], X[50:90]))
850+
y = np.hstack((y[:40], y[50:90]))
851+
sample_weight = np.zeros_like(y)
852+
sample_weight[::2] = 1
853+
854+
base_estimator = LogisticRegression()
855+
calibrated_clf_without_weights = CalibratedClassifierCV(
856+
base_estimator,
857+
method=method,
858+
ensemble=ensemble,
859+
cv=2,
860+
)
861+
calibrated_clf_with_weights = clone(calibrated_clf_without_weights)
862+
863+
calibrated_clf_with_weights.fit(X, y, sample_weight=sample_weight)
864+
calibrated_clf_without_weights.fit(X[::2], y[::2])
865+
866+
# Check that the underlying fitted estimators have the same coefficients
867+
for est_with_weights, est_without_weights in zip(
868+
calibrated_clf_with_weights.calibrated_classifiers_,
869+
calibrated_clf_without_weights.calibrated_classifiers_,
870+
):
871+
assert_allclose(
872+
est_with_weights.base_estimator.coef_,
873+
est_without_weights.base_estimator.coef_,
874+
)
875+
876+
# Check that the predictions are the same
877+
y_pred_with_weights = calibrated_clf_with_weights.predict_proba(X)
878+
y_pred_without_weights = calibrated_clf_without_weights.predict_proba(X)
879+
880+
assert_allclose(y_pred_with_weights, y_pred_without_weights)

0 commit comments

Comments
 (0)
0