@@ -1015,6 +1015,13 @@ class CalibrationDisplay:
1015
1015
estimator_name : str, default=None
1016
1016
Name of estimator. If None, the estimator name is not shown.
1017
1017
1018
+ pos_label : str or int, default=None
1019
+ The positive class when computing the calibration curve.
1020
+ By default, `estimators.classes_[1]` is considered as the
1021
+ positive class.
1022
+
1023
+ .. versionadded:: 1.1
1024
+
1018
1025
Attributes
1019
1026
----------
1020
1027
line_ : matplotlib Artist
@@ -1054,11 +1061,14 @@ class CalibrationDisplay:
1054
1061
<...>
1055
1062
"""
1056
1063
1057
- def __init__ (self , prob_true , prob_pred , y_prob , * , estimator_name = None ):
1064
+ def __init__ (
1065
+ self , prob_true , prob_pred , y_prob , * , estimator_name = None , pos_label = None
1066
+ ):
1058
1067
self .prob_true = prob_true
1059
1068
self .prob_pred = prob_pred
1060
1069
self .y_prob = y_prob
1061
1070
self .estimator_name = estimator_name
1071
+ self .pos_label = pos_label
1062
1072
1063
1073
def plot (self , * , ax = None , name = None , ref_line = True , ** kwargs ):
1064
1074
"""Plot visualization.
@@ -1095,6 +1105,9 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
1095
1105
fig , ax = plt .subplots ()
1096
1106
1097
1107
name = self .estimator_name if name is None else name
1108
+ info_pos_label = (
1109
+ f"(Positive class: { self .pos_label } )" if self .pos_label is not None else ""
1110
+ )
1098
1111
1099
1112
line_kwargs = {}
1100
1113
if name is not None :
@@ -1110,7 +1123,9 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
1110
1123
if "label" in line_kwargs :
1111
1124
ax .legend (loc = "lower right" )
1112
1125
1113
- ax .set (xlabel = "Mean predicted probability" , ylabel = "Fraction of positives" )
1126
+ xlabel = f"Mean predicted probability { info_pos_label } "
1127
+ ylabel = f"Fraction of positives { info_pos_label } "
1128
+ ax .set (xlabel = xlabel , ylabel = ylabel )
1114
1129
1115
1130
self .ax_ = ax
1116
1131
self .figure_ = ax .figure
@@ -1125,6 +1140,7 @@ def from_estimator(
1125
1140
* ,
1126
1141
n_bins = 5 ,
1127
1142
strategy = "uniform" ,
1143
+ pos_label = None ,
1128
1144
name = None ,
1129
1145
ref_line = True ,
1130
1146
ax = None ,
@@ -1170,6 +1186,13 @@ def from_estimator(
1170
1186
- `'quantile'`: The bins have the same number of samples and depend
1171
1187
on predicted probabilities.
1172
1188
1189
+ pos_label : str or int, default=None
1190
+ The positive class when computing the calibration curve.
1191
+ By default, `estimators.classes_[1]` is considered as the
1192
+ positive class.
1193
+
1194
+ .. versionadded:: 1.1
1195
+
1173
1196
name : str, default=None
1174
1197
Name for labeling curve. If `None`, the name of the estimator is
1175
1198
used.
@@ -1217,10 +1240,8 @@ def from_estimator(
1217
1240
if not is_classifier (estimator ):
1218
1241
raise ValueError ("'estimator' should be a fitted classifier." )
1219
1242
1220
- # FIXME: `pos_label` should not be set to None
1221
- # We should allow any int or string in `calibration_curve`.
1222
- y_prob , _ = _get_response (
1223
- X , estimator , response_method = "predict_proba" , pos_label = None
1243
+ y_prob , pos_label = _get_response (
1244
+ X , estimator , response_method = "predict_proba" , pos_label = pos_label
1224
1245
)
1225
1246
1226
1247
name = name if name is not None else estimator .__class__ .__name__
@@ -1229,6 +1250,7 @@ def from_estimator(
1229
1250
y_prob ,
1230
1251
n_bins = n_bins ,
1231
1252
strategy = strategy ,
1253
+ pos_label = pos_label ,
1232
1254
name = name ,
1233
1255
ref_line = ref_line ,
1234
1256
ax = ax ,
@@ -1243,6 +1265,7 @@ def from_predictions(
1243
1265
* ,
1244
1266
n_bins = 5 ,
1245
1267
strategy = "uniform" ,
1268
+ pos_label = None ,
1246
1269
name = None ,
1247
1270
ref_line = True ,
1248
1271
ax = None ,
@@ -1283,6 +1306,13 @@ def from_predictions(
1283
1306
- `'quantile'`: The bins have the same number of samples and depend
1284
1307
on predicted probabilities.
1285
1308
1309
+ pos_label : str or int, default=None
1310
+ The positive class when computing the calibration curve.
1311
+ By default, `estimators.classes_[1]` is considered as the
1312
+ positive class.
1313
+
1314
+ .. versionadded:: 1.1
1315
+
1286
1316
name : str, default=None
1287
1317
Name for labeling curve.
1288
1318
@@ -1328,11 +1358,16 @@ def from_predictions(
1328
1358
check_matplotlib_support (method_name )
1329
1359
1330
1360
prob_true , prob_pred = calibration_curve (
1331
- y_true , y_prob , n_bins = n_bins , strategy = strategy
1361
+ y_true , y_prob , n_bins = n_bins , strategy = strategy , pos_label = pos_label
1332
1362
)
1333
- name = name if name is not None else "Classifier"
1363
+ name = "Classifier" if name is None else name
1364
+ pos_label = _check_pos_label_consistency (pos_label , y_true )
1334
1365
1335
1366
disp = cls (
1336
- prob_true = prob_true , prob_pred = prob_pred , y_prob = y_prob , estimator_name = name
1367
+ prob_true = prob_true ,
1368
+ prob_pred = prob_pred ,
1369
+ y_prob = y_prob ,
1370
+ estimator_name = name ,
1371
+ pos_label = pos_label ,
1337
1372
)
1338
1373
return disp .plot (ax = ax , ref_line = ref_line , ** kwargs )
0 commit comments