CalibratedClassifierCV does not handle well sample_weight when ensemble=False · Issue #20610 · scikit-learn/scikit-learn
Closed
@JulienB-78

Description


CalibratedClassifierCV does not handle sample_weight correctly when ensemble=False.

In the fit method, sample_weight is not passed to cross_val_predict when the out-of-fold prediction scores are generated (https://github.com/scikit-learn/scikit-learn/blob/2beed5584/sklearn/calibration.py#L325), whereas it is passed to fit when the classifier is refitted on the entire dataset (https://github.com/scikit-learn/scikit-learn/blob/2beed5584/sklearn/calibration.py#L328).

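This makes the calibration fail, because it relies on the assumption that the classifiers built on each CV split by cross_val_predict behave similarly to the one trained on the whole dataset at the end. Without the weights, the per-split classifiers solve a differently weighted problem, so the calibrator is fitted on scores that do not correspond to those of the final classifier.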
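The mismatch can be illustrated outside CalibratedClassifierCV. The sketch below is a minimal illustration under stated assumptions, not scikit-learn code: it uses a LogisticRegression base estimator (swapped in for the RandomForestClassifier of the reproduction script because its scores respond smoothly to the weights), synthetic data from make_classification, and a hypothetical helper out_of_fold_scores that mimics cross_val_predict with manual StratifiedKFold splits. It compares the mean score of the weighted final model with the mean out-of-fold score obtained with and without forwarding the weights to each fold.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold


def out_of_fold_scores(X, y, sample_weight=None, n_splits=5):
    """Hypothetical helper: positive-class scores, each predicted by a model not trained on it."""
    scores = np.empty(len(y))
    for train, test in StratifiedKFold(n_splits=n_splits).split(X, y):
        model = LogisticRegression()
        if sample_weight is None:
            model.fit(X[train], y[train])
        else:
            # Slice the weights with the same training indices as X and y.
            model.fit(X[train], y[train], sample_weight=sample_weight[train])
        scores[test] = model.predict_proba(X[test])[:, 1]
    return scores


X, y = make_classification(n_samples=2000, random_state=0)
weight = np.where(y == 0, 2.0, 1.0)  # same scheme as the reproduction script: class 0 weighs 2

# The final classifier is fitted with the weights, as in CalibratedClassifierCV.fit.
final_scores = LogisticRegression().fit(X, y, sample_weight=weight).predict_proba(X)[:, 1]
unweighted_oof = out_of_fold_scores(X, y)        # per-fold models ignore the weights
weighted_oof = out_of_fold_scores(X, y, weight)  # per-fold models receive their slice

print(f"final model mean score:      {final_scores.mean():.3f}")
print(f"unweighted out-of-fold mean: {unweighted_oof.mean():.3f}")
print(f"weighted out-of-fold mean:   {weighted_oof.mean():.3f}")
```

With class 0 weighted twice, the weighted models should push scores toward class 0, so the weighted out-of-fold mean is expected to track the final model while the unweighted one stays shifted upward.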

To fix the bug, I suggest passing sample_weight to cross_val_predict through the fit_params dictionary:

pred_method = partial(
    cross_val_predict, estimator=this_estimator, X=X, y=y,
    cv=cv, method=method_name, n_jobs=self.n_jobs,
    fit_params={"sample_weight": sample_weight}
)
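One caveat with the snippet as written: when sample_weight is None it still passes fit_params={"sample_weight": None}, which raises a TypeError for base estimators whose fit has no sample_weight argument. The sketch below shows a guarded variant as a standalone hypothetical helper, weighted_cross_val_predict, rather than as a patch to calibration.py. It forwards the weights only when they are given and has_fit_parameter reports support (otherwise, in this sketch, they are silently dropped). It also picks the keyword name by inspecting cross_val_predict's signature, since scikit-learn 1.4 added `params` and deprecated `fit_params` in its favour; the data and base estimator are illustrative assumptions.

```python
import inspect

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
from sklearn.utils.validation import has_fit_parameter


def weighted_cross_val_predict(estimator, X, y, sample_weight=None, cv=5,
                               method="predict_proba"):
    """Hypothetical helper: cross_val_predict forwarding sample_weight to each fold's fit."""
    kwargs = {}
    # Forward the weights only when they exist and the estimator's fit accepts them;
    # passing sample_weight=None to a fit without that argument raises a TypeError.
    if sample_weight is not None and has_fit_parameter(estimator, "sample_weight"):
        weights = {"sample_weight": sample_weight}
        # scikit-learn >= 1.4 takes this dict as `params`; older releases as `fit_params`.
        if "params" in inspect.signature(cross_val_predict).parameters:
            kwargs["params"] = weights
        else:
            kwargs["fit_params"] = weights
    # cross_val_predict slices each array-like entry with the fold's training indices.
    return cross_val_predict(estimator, X, y, cv=cv, method=method, **kwargs)


X, y = make_classification(n_samples=2000, random_state=0)
weight = np.where(y == 0, 2.0, 1.0)
proba_unweighted = weighted_cross_val_predict(LogisticRegression(), X, y)
proba_weighted = weighted_cross_val_predict(LogisticRegression(), X, y, sample_weight=weight)
print(proba_unweighted[:, 1].mean(), proba_weighted[:, 1].mean())
```

Inside CalibratedClassifierCV the same guard would reuse the estimator's existing sample-weight support check instead of calling has_fit_parameter again.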

Example to reproduce the issue:

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

from sklearn.datasets import make_hastie_10_2

from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.calibration import CalibratedClassifierCV

X, y = make_hastie_10_2(50000)
y[y == -1] = 0

X_0 = X[y == 0, :]
X_1 = X[y == 1, :]

y_0 = y[y == 0]
y_1 = y[y == 1]

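# Keep only len(y_1) // 2 samples with y == 0, so class 0 is half as frequent as class 1
X = np.vstack([X_0[:len(y_1) // 2, :], X_1])
y = np.hstack([y_0[:len(y_1) // 2], y_1])

# Compute weights that compensate for the discarded samples:
# class 0 samples get weight 2, class 1 samples weight 1
weight = (y == 0) + 1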

X_train, X_test, y_train, y_test, weight_train, weight_test = train_test_split(X, y, weight)

calib = CalibratedClassifierCV(RandomForestClassifier(n_estimators=5, max_depth=3), ensemble=False, n_jobs=-1)
calib.fit(X_train, y_train, sample_weight=weight_train)

pred = calib.predict_proba(X_test)[:, 1]

# Check calibration so that both classes have equal importance, even though
# class 0 is less frequent (normalizing each class's histogram matches the sample weights)

df_target_pred = pd.DataFrame({"target": y_test, "pred": pred})

hist_0 = np.histogram(df_target_pred.loc[df_target_pred.target == 0, 'pred'], bins=np.linspace(0, 1, 6), density=True)
hist_1 = np.histogram(df_target_pred.loc[df_target_pred.target == 1,  'pred'], bins=np.linspace(0, 1, 6), density=True)

plt.bar(hist_0[1][:-1], hist_0[0], align='edge', label='0', alpha=0.5, width=0.2)
plt.bar(hist_1[1][:-1], hist_1[0], align='edge', label='1', alpha=0.5, width=0.2)
plt.ylabel('Prediction histograms')
plt.xlabel('Predicted score')
plt.legend()
ax2 = plt.gca().twinx()
ax2.plot([0, 1], [0, 1], "k:", label="Perfectly calibrated")
ax2.plot((hist_0[1][:-1] + hist_0[1][1:]) / 2, hist_1[0] / (hist_0[0] + hist_1[0] + 1e-10))
ax2.grid(False)
ax2.set_ylabel('Fraction of positives')
plt.show()
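The histogram check in the script gives both classes equal importance by normalizing each class's histogram (density=True). Since class 0 samples carry weight 2 and class 0 has half as many samples, the two classes have roughly equal total weight, so essentially the same curve can be computed directly from the weights: in each bin, the weighted fraction of positives sum(w * y) / sum(w). A minimal numpy sketch; weighted_reliability_curve is a hypothetical helper, not a scikit-learn function, and it uses the same five equal-width bins as the script.

```python
import numpy as np


def weighted_reliability_curve(y_true, y_prob, sample_weight, n_bins=5):
    """Hypothetical helper: weighted fraction of positives in equal-width score bins.

    Returns (bin_centers, fraction_of_positives, total_weight). Bins that receive
    no weight get NaN as their fraction of positives.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_prob = np.asarray(y_prob, dtype=float)
    w = np.asarray(sample_weight, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # Digitizing against the inner edges gives bin indices 0..n_bins-1, with the last
    # bin closed on the right (a score of exactly 1.0 lands there), like np.histogram.
    idx = np.digitize(y_prob, edges[1:-1])
    total = np.bincount(idx, weights=w, minlength=n_bins)
    positive = np.bincount(idx, weights=w * y_true, minlength=n_bins)
    with np.errstate(divide="ignore", invalid="ignore"):
        fraction = positive / total
    centers = (edges[:-1] + edges[1:]) / 2
    return centers, fraction, total


# Toy input: two samples in the first bin, two in the last, none in between.
centers, fraction, total = weighted_reliability_curve(
    y_true=[0, 1, 1, 0], y_prob=[0.1, 0.15, 0.9, 0.95], sample_weight=[2, 1, 1, 2]
)
print(centers)
print(fraction)
```

In the reproduction script it could be called as weighted_reliability_curve(y_test, pred, weight_test); the result should approximately match the density-based curve, because the test split has roughly, not exactly, equal total weight per class.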

Versions

System:
python: 3.7.10 (default, Feb 26 2021, 13:06:18) [MSC v.1916 64 bit (AMD64)]
executable: C:\HOMEWARE\Anaconda3-Windows-x86_64\envs\python37\python.exe
machine: Windows-10-10.0.18362-SP0

Python dependencies:
pip: 21.1.3
setuptools: 52.0.0.post20210125
sklearn: 0.24.2
numpy: 1.20.2
scipy: 1.6.2
Cython: None
pandas: 1.2.5
matplotlib: 3.3.4
joblib: 1.0.1
threadpoolctl: 2.1.0

Built with OpenMP: True
