Describe the bug

This may be very uncommon and not a big problem.

For CalibratedClassifierCV, if a CV training subset contains fewer classes than the test subset, e.g.:

```
y_test  = [1 1 2 2 3 3]
y_train = [1 1 1 2 2 2]
```

the classifier for that split (in the ensemble) will only be fit on 2 classes. At prediction time, its output will have shape (n_samples, 2), as it can only predict those 2 classes. The probabilities for the 3rd class are filled with 0's, see:

scikit-learn/sklearn/calibration.py, line 367 in fd23727
resulting in a lower probability than expected. We also average the probabilities from the ensemble of classifier/calibrator pairs, dividing by the number of pairs in the ensemble, see:

scikit-learn/sklearn/calibration.py, line 214 in fd23727

This also gives a lower probability than expected, as one member of the ensemble contributed all 0's.
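A minimal sketch of these two effects with plain numpy arrays (not the scikit-learn internals; the probability values are made up for illustration), assuming 3 classes overall:

```python
import numpy as np

all_classes = np.array([1, 2, 3])

# Member 1 was trained on all 3 classes:
proba_full = np.array([[0.2, 0.3, 0.5]])

# Member 2 only saw classes [1, 2]; its (n_samples, 2) output is
# padded with zeros for the unseen class 3:
proba_seen = np.array([[0.4, 0.6]])
padded = np.zeros((1, 3))
padded[:, np.searchsorted(all_classes, [1, 2])] = proba_seen

# Averaging over the ensemble divides by the number of members,
# so the zero column halves the probability of class 3:
mean_proba = (proba_full + padded) / 2
print(mean_proba)  # class 3 gets half of member 1's estimate
```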
Steps/Code to Reproduce
I am not sure of the best way to reproduce, but the code below results in a lower proba for class 3 than you would expect:
```python
import numpy as np
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold

X, _ = make_classification(n_samples=12, n_features=4, n_classes=3,
                           n_clusters_per_class=1, random_state=7)
y = [1, 1, 1, 2, 2, 2, 1, 1, 2, 2, 3, 3]

# Make class 3 easier to predict
X[-2:, :] = np.abs(X[-2:, :]) * 10

clf = RandomForestClassifier()
kfold = KFold(n_splits=2)
calb_clf = CalibratedClassifierCV(clf, cv=kfold)
calb_clf.fit(X, y)

# Predict the last sample in X
calb_clf.predict_proba(X[-1, :].reshape((1, -1)))
```
Gives: [[0.31342392 0.61173178 0.0748443 ]]
I would suggest adding a warning whenever a train subset does not contain all the classes in y. Again, I am not sure how much of a problem this is, as it is probably uncommon for a train subset to have fewer classes than are present in the full y.
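A rough sketch of what such a check could look like (hypothetical, not existing scikit-learn API), using the y from the reproduction above:

```python
import warnings
import numpy as np
from sklearn.model_selection import KFold

y = np.array([1, 1, 1, 2, 2, 2, 1, 1, 2, 2, 3, 3])
classes = np.unique(y)

# Warn whenever a CV train split lacks a class present in the full y.
for train_idx, _ in KFold(n_splits=2).split(y):
    missing = np.setdiff1d(classes, y[train_idx])
    if missing.size:
        warnings.warn(
            f"Train subset is missing classes {missing.tolist()}; "
            "their predicted probabilities will be filled with zeros."
        )
```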
I'm not sure we want to care about this, because users should use StratifiedKFold, and in general I wouldn't expect calibration to be useful with so few examples, which is when such discrepancies might happen.
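To illustrate the StratifiedKFold point (a small sketch, not part of the original report): with the y from the reproduction, plain KFold produces a train split that misses class 3, while StratifiedKFold keeps every class in every train split.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

y = np.array([1, 1, 1, 2, 2, 2, 1, 1, 2, 2, 3, 3])
X = np.zeros((len(y), 1))  # dummy features; only y matters here

# Compare which classes each CV strategy leaves in the train splits.
for name, cv in [("KFold", KFold(n_splits=2)),
                 ("StratifiedKFold", StratifiedKFold(n_splits=2))]:
    for train_idx, _ in cv.split(X, y):
        print(name, "train classes:", np.unique(y[train_idx]))
```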
@NicolasHug thinking about this more, we specifically amended this function to allow cases where the train and test subsets have different numbers of classes: #7799, but I don't think we handle it well when it occurs...

I feel like we shouldn't have allowed this if we aren't going to 'care' about it...?
cc @NicolasHug @ogrisel, who have been reviewing other calibration work.