ENH add pos_label to CalibrationDisplay (#21038) · scikit-learn/scikit-learn@90bef46 · GitHub
Commit 90bef46

ENH add pos_label to CalibrationDisplay (#21038)
1 parent 87838c2 commit 90bef46

File tree

3 files changed: +78 -11 lines changed

  doc/whats_new/v1.1.rst
  sklearn/calibration.py
  sklearn/tests/test_calibration.py

doc/whats_new/v1.1.rst

Lines changed: 4 additions & 0 deletions

@@ -45,6 +45,10 @@ Changelog
   `pos_label` to specify the positive class label.
   :pr:`21032` by :user:`Guillaume Lemaitre <glemaitre>`.
 
+- |Enhancement| :class:`CalibrationDisplay` accepts a parameter `pos_label` to
+  add this information to the plot.
+  :pr:`21038` by :user:`Guillaume Lemaitre <glemaitre>`.
+
 :mod:`sklearn.cross_decomposition`
 ..................................
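As a usage illustration of this changelog entry, here is a minimal sketch, assuming a build that includes this commit (scikit-learn >= 1.1) and matplotlib; the synthetic dataset and variable names are illustrative, not part of the commit:

import matplotlib.pyplot as plt

from sklearn.calibration import CalibrationDisplay
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Illustrative binary problem; any fitted binary classifier would do.
X, y = make_classification(n_samples=1000, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = LogisticRegression().fit(X_train, y_train)

# With this change, `pos_label` selects the class treated as positive and
# is echoed in the axis labels, e.g. "Mean predicted probability (Positive class: 0)".
disp = CalibrationDisplay.from_estimator(clf, X_test, y_test, pos_label=0)
plt.show()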

sklearn/calibration.py

Lines changed: 44 additions & 9 deletions

@@ -1015,6 +1015,13 @@ class CalibrationDisplay:
     estimator_name : str, default=None
         Name of estimator. If None, the estimator name is not shown.
 
+    pos_label : str or int, default=None
+        The positive class when computing the calibration curve.
+        By default, `estimators.classes_[1]` is considered as the
+        positive class.
+
+        .. versionadded:: 1.1
+
     Attributes
     ----------
     line_ : matplotlib Artist

@@ -1054,11 +1061,14 @@ class CalibrationDisplay:
     <...>
     """
 
-    def __init__(self, prob_true, prob_pred, y_prob, *, estimator_name=None):
+    def __init__(
+        self, prob_true, prob_pred, y_prob, *, estimator_name=None, pos_label=None
+    ):
         self.prob_true = prob_true
         self.prob_pred = prob_pred
         self.y_prob = y_prob
         self.estimator_name = estimator_name
+        self.pos_label = pos_label
 
     def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
         """Plot visualization.

@@ -1095,6 +1105,9 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
             fig, ax = plt.subplots()
 
         name = self.estimator_name if name is None else name
+        info_pos_label = (
+            f"(Positive class: {self.pos_label})" if self.pos_label is not None else ""
+        )
 
         line_kwargs = {}
         if name is not None:

@@ -1110,7 +1123,9 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
         if "label" in line_kwargs:
             ax.legend(loc="lower right")
 
-        ax.set(xlabel="Mean predicted probability", ylabel="Fraction of positives")
+        xlabel = f"Mean predicted probability {info_pos_label}"
+        ylabel = f"Fraction of positives {info_pos_label}"
+        ax.set(xlabel=xlabel, ylabel=ylabel)
 
         self.ax_ = ax
         self.figure_ = ax.figure

@@ -1125,6 +1140,7 @@ def from_estimator(
         *,
         n_bins=5,
         strategy="uniform",
+        pos_label=None,
         name=None,
         ref_line=True,
         ax=None,

@@ -1170,6 +1186,13 @@ def from_estimator(
             - `'quantile'`: The bins have the same number of samples and depend
               on predicted probabilities.
 
+        pos_label : str or int, default=None
+            The positive class when computing the calibration curve.
+            By default, `estimators.classes_[1]` is considered as the
+            positive class.
+
+            .. versionadded:: 1.1
+
         name : str, default=None
             Name for labeling curve. If `None`, the name of the estimator is
             used.

@@ -1217,10 +1240,8 @@ def from_estimator(
         if not is_classifier(estimator):
             raise ValueError("'estimator' should be a fitted classifier.")
 
-        # FIXME: `pos_label` should not be set to None
-        # We should allow any int or string in `calibration_curve`.
-        y_prob, _ = _get_response(
-            X, estimator, response_method="predict_proba", pos_label=None
+        y_prob, pos_label = _get_response(
+            X, estimator, response_method="predict_proba", pos_label=pos_label
         )
 
         name = name if name is not None else estimator.__class__.__name__

@@ -1229,6 +1250,7 @@ def from_estimator(
             y_prob,
             n_bins=n_bins,
             strategy=strategy,
+            pos_label=pos_label,
             name=name,
             ref_line=ref_line,
             ax=ax,

@@ -1243,6 +1265,7 @@ def from_predictions(
         *,
         n_bins=5,
         strategy="uniform",
+        pos_label=None,
         name=None,
         ref_line=True,
         ax=None,

@@ -1283,6 +1306,13 @@ def from_predictions(
             - `'quantile'`: The bins have the same number of samples and depend
               on predicted probabilities.
 
+        pos_label : str or int, default=None
+            The positive class when computing the calibration curve.
+            By default, `estimators.classes_[1]` is considered as the
+            positive class.
+
+            .. versionadded:: 1.1
+
         name : str, default=None
             Name for labeling curve.
 
@@ -1328,11 +1358,16 @@ def from_predictions(
         check_matplotlib_support(method_name)
 
         prob_true, prob_pred = calibration_curve(
-            y_true, y_prob, n_bins=n_bins, strategy=strategy
+            y_true, y_prob, n_bins=n_bins, strategy=strategy, pos_label=pos_label
         )
-        name = name if name is not None else "Classifier"
+        name = "Classifier" if name is None else name
+        pos_label = _check_pos_label_consistency(pos_label, y_true)
 
         disp = cls(
-            prob_true=prob_true, prob_pred=prob_pred, y_prob=y_prob, estimator_name=name
+            prob_true=prob_true,
+            prob_pred=prob_pred,
+            y_prob=y_prob,
+            estimator_name=name,
+            pos_label=pos_label,
         )
         return disp.plot(ax=ax, ref_line=ref_line, **kwargs)
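A small sketch of the updated `from_predictions` path with a string positive label, assuming matplotlib is installed; the toy labels and probabilities below are invented for illustration:

import numpy as np

from sklearn.calibration import CalibrationDisplay

# Toy ground truth and predicted P(y == "spam"); values are made up.
y_true = np.array(["spam", "ham", "spam", "ham", "spam", "ham", "spam", "ham"])
y_prob = np.array([0.9, 0.2, 0.8, 0.35, 0.6, 0.1, 0.7, 0.4])

disp = CalibrationDisplay.from_predictions(
    y_true, y_prob, n_bins=2, pos_label="spam", name="toy classifier"
)
# The positive class is propagated to the axis labels.
print(disp.ax_.get_xlabel())  # Mean predicted probability (Positive class: spam)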

sklearn/tests/test_calibration.py

Lines changed: 30 additions & 2 deletions

@@ -703,8 +703,8 @@ def test_calibration_display_compute(pyplot, iris_data_binary, n_bins, strategy)
     assert isinstance(viz.ax_, mpl.axes.Axes)
     assert isinstance(viz.figure_, mpl.figure.Figure)
 
-    assert viz.ax_.get_xlabel() == "Mean predicted probability"
-    assert viz.ax_.get_ylabel() == "Fraction of positives"
+    assert viz.ax_.get_xlabel() == "Mean predicted probability (Positive class: 1)"
+    assert viz.ax_.get_ylabel() == "Fraction of positives (Positive class: 1)"
     assert viz.line_.get_label() == "LogisticRegression"
 
 
@@ -823,6 +823,34 @@ def test_calibration_curve_pos_label(dtype_y_str):
     assert_allclose(prob_true, [0, 0, 0.5, 1])
 
 
+@pytest.mark.parametrize("pos_label, expected_pos_label", [(None, 1), (0, 0), (1, 1)])
+def test_calibration_display_pos_label(
+    pyplot, iris_data_binary, pos_label, expected_pos_label
+):
+    """Check the behaviour of `pos_label` in the `CalibrationDisplay`."""
+    X, y = iris_data_binary
+
+    lr = LogisticRegression().fit(X, y)
+    viz = CalibrationDisplay.from_estimator(lr, X, y, pos_label=pos_label)
+
+    y_prob = lr.predict_proba(X)[:, expected_pos_label]
+    prob_true, prob_pred = calibration_curve(y, y_prob, pos_label=pos_label)
+
+    assert_allclose(viz.prob_true, prob_true)
+    assert_allclose(viz.prob_pred, prob_pred)
+    assert_allclose(viz.y_prob, y_prob)
+
+    assert (
+        viz.ax_.get_xlabel()
+        == f"Mean predicted probability (Positive class: {expected_pos_label})"
+    )
+    assert (
+        viz.ax_.get_ylabel()
+        == f"Fraction of positives (Positive class: {expected_pos_label})"
+    )
+    assert viz.line_.get_label() == "LogisticRegression"
+
+
 @pytest.mark.parametrize("method", ["sigmoid", "isotonic"])
 @pytest.mark.parametrize("ensemble", [True, False])
 def test_calibrated_classifier_cv_double_sample_weights_equivalence(method, ensemble):
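As a rough standalone counterpart to the new test, here is a hedged sketch of the equivalence it asserts, assuming a binary subset of the iris dataset in place of the `iris_data_binary` fixture:

from numpy.testing import assert_allclose

from sklearn.calibration import CalibrationDisplay, calibration_curve
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

# Binary subset of iris, similar in spirit to the `iris_data_binary` fixture.
X, y = load_iris(return_X_y=True)
X, y = X[y < 2], y[y < 2]

lr = LogisticRegression().fit(X, y)
viz = CalibrationDisplay.from_estimator(lr, X, y, pos_label=0)

# The display should match `calibration_curve` computed on P(y == 0).
prob_true, prob_pred = calibration_curve(y, lr.predict_proba(X)[:, 0], pos_label=0)
assert_allclose(viz.prob_true, prob_true)
assert_allclose(viz.prob_pred, prob_pred)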
