MAINT Introduce BinaryClassifierCurveDisplayMixin (#25969) · thomasjpfan/scikit-learn@7f1e15d · GitHub
[go: up one dir, main page]

Skip to content

Commit 7f1e15d

Browse files
MAINT Introduce BinaryClassifierCurveDisplayMixin (scikit-learn#25969)
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent 55f91cc commit 7f1e15d

12 files changed

+270
-277
lines changed

sklearn/calibration.py

Lines changed: 23 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,16 @@
3030
from .utils import (
3131
column_or_1d,
3232
indexable,
33-
check_matplotlib_support,
3433
_safe_indexing,
3534
)
36-
from .utils._response import _get_response_values_binary
3735

38-
from .utils.multiclass import check_classification_targets, type_of_target
36+
from .utils.multiclass import check_classification_targets
3937
from .utils.parallel import delayed, Parallel
4038
from .utils._param_validation import StrOptions, HasMethods, Hidden
39+
from .utils._plotting import _BinaryClassifierCurveDisplayMixin
4140
from .utils.validation import (
4241
_check_fit_params,
42+
_check_pos_label_consistency,
4343
_check_sample_weight,
4444
_num_samples,
4545
check_consistent_length,
@@ -48,7 +48,6 @@
4848
from .isotonic import IsotonicRegression
4949
from .svm import LinearSVC
5050
from .model_selection import check_cv, cross_val_predict
51-
from .metrics._base import _check_pos_label_consistency
5251

5352

5453
class CalibratedClassifierCV(ClassifierMixin, MetaEstimatorMixin, BaseEstimator):
@@ -1013,7 +1012,7 @@ def calibration_curve(
10131012
return prob_true, prob_pred
10141013

10151014

1016-
class CalibrationDisplay:
1015+
class CalibrationDisplay(_BinaryClassifierCurveDisplayMixin):
10171016
"""Calibration curve (also known as reliability diagram) visualization.
10181017
10191018
It is recommended to use
@@ -1124,13 +1123,8 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
11241123
display : :class:`~sklearn.calibration.CalibrationDisplay`
11251124
Object that stores computed values.
11261125
"""
1127-
check_matplotlib_support("CalibrationDisplay.plot")
1128-
import matplotlib.pyplot as plt
1126+
self.ax_, self.figure_, name = self._validate_plot_params(ax=ax, name=name)
11291127

1130-
if ax is None:
1131-
fig, ax = plt.subplots()
1132-
1133-
name = self.estimator_name if name is None else name
11341128
info_pos_label = (
11351129
f"(Positive class: {self.pos_label})" if self.pos_label is not None else ""
11361130
)
@@ -1141,20 +1135,20 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
11411135
line_kwargs.update(**kwargs)
11421136

11431137
ref_line_label = "Perfectly calibrated"
1144-
existing_ref_line = ref_line_label in ax.get_legend_handles_labels()[1]
1138+
existing_ref_line = ref_line_label in self.ax_.get_legend_handles_labels()[1]
11451139
if ref_line and not existing_ref_line:
1146-
ax.plot([0, 1], [0, 1], "k:", label=ref_line_label)
1147-
self.line_ = ax.plot(self.prob_pred, self.prob_true, "s-", **line_kwargs)[0]
1140+
self.ax_.plot([0, 1], [0, 1], "k:", label=ref_line_label)
1141+
self.line_ = self.ax_.plot(self.prob_pred, self.prob_true, "s-", **line_kwargs)[
1142+
0
1143+
]
11481144

11491145
# We always have to show the legend for at least the reference line
1150-
ax.legend(loc="lower right")
1146+
self.ax_.legend(loc="lower right")
11511147

11521148
xlabel = f"Mean predicted probability {info_pos_label}"
11531149
ylabel = f"Fraction of positives {info_pos_label}"
1154-
ax.set(xlabel=xlabel, ylabel=ylabel)
1150+
self.ax_.set(xlabel=xlabel, ylabel=ylabel)
11551151

1156-
self.ax_ = ax
1157-
self.figure_ = ax.figure
11581152
return self
11591153

11601154
@classmethod
@@ -1260,15 +1254,15 @@ def from_estimator(
12601254
>>> disp = CalibrationDisplay.from_estimator(clf, X_test, y_test)
12611255
>>> plt.show()
12621256
"""
1263-
method_name = f"{cls.__name__}.from_estimator"
1264-
check_matplotlib_support(method_name)
1265-
1266-
check_is_fitted(estimator)
1267-
y_prob, pos_label = _get_response_values_binary(
1268-
estimator, X, response_method="predict_proba", pos_label=pos_label
1257+
y_prob, pos_label, name = cls._validate_and_get_response_values(
1258+
estimator,
1259+
X,
1260+
y,
1261+
response_method="predict_proba",
1262+
pos_label=pos_label,
1263+
name=name,
12691264
)
12701265

1271-
name = name if name is not None else estimator.__class__.__name__
12721266
return cls.from_predictions(
12731267
y,
12741268
y_prob,
@@ -1378,26 +1372,19 @@ def from_predictions(
13781372
>>> disp = CalibrationDisplay.from_predictions(y_test, y_prob)
13791373
>>> plt.show()
13801374
"""
1381-
method_name = f"{cls.__name__}.from_predictions"
1382-
check_matplotlib_support(method_name)
1383-
1384-
target_type = type_of_target(y_true)
1385-
if target_type != "binary":
1386-
raise ValueError(
1387-
f"The target y is not binary. Got {target_type} type of target."
1388-
)
1375+
pos_label_validated, name = cls._validate_from_predictions_params(
1376+
y_true, y_prob, sample_weight=None, pos_label=pos_label, name=name
1377+
)
13891378

13901379
prob_true, prob_pred = calibration_curve(
13911380
y_true, y_prob, n_bins=n_bins, strategy=strategy, pos_label=pos_label
13921381
)
1393-
name = "Classifier" if name is None else name
1394-
pos_label = _check_pos_label_consistency(pos_label, y_true)
13951382

13961383
disp = cls(
13971384
prob_true=prob_true,
13981385
prob_pred=prob_pred,
13991386
y_prob=y_prob,
14001387
estimator_name=name,
1401-
pos_label=pos_label,
1388+
pos_label=pos_label_validated,
14021389
)
14031390
return disp.plot(ax=ax, ref_line=ref_line, **kwargs)

sklearn/metrics/_base.py

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -197,55 +197,3 @@ def _average_multiclass_ovo_score(binary_metric, y_true, y_score, average="macro
197197
pair_scores[ix] = (a_true_score + b_true_score) / 2
198198

199199
return np.average(pair_scores, weights=prevalence)
200-
201-
202-
def _check_pos_label_consistency(pos_label, y_true):
203-
"""Check if `pos_label` needs to be specified or not.
204-
205-
In binary classification, we fix `pos_label=1` if the labels are in the set
206-
{-1, 1} or {0, 1}. Otherwise, we raise an error asking to specify the
207-
`pos_label` parameters.
208-
209-
Parameters
210-
----------
211-
pos_label : int, str or None
212-
The positive label.
213-
y_true : ndarray of shape (n_samples,)
214-
The target vector.
215-
216-
Returns
217-
-------
218-
pos_label : int
219-
If `pos_label` can be inferred, it will be returned.
220-
221-
Raises
222-
------
223-
ValueError
224-
In the case that `y_true` does not have labels in {-1, 1} or {0, 1},
225-
it will raise a `ValueError`.
226-
"""
227-
# ensure binary classification if pos_label is not specified
228-
# classes.dtype.kind in ('O', 'U', 'S') is required to avoid
229-
# triggering a FutureWarning by calling np.array_equal(a, b)
230-
# when elements in the two arrays are not comparable.
231-
classes = np.unique(y_true)
232-
if pos_label is None and (
233-
classes.dtype.kind in "OUS"
234-
or not (
235-
np.array_equal(classes, [0, 1])
236-
or np.array_equal(classes, [-1, 1])
237-
or np.array_equal(classes, [0])
238-
or np.array_equal(classes, [-1])
239-
or np.array_equal(classes, [1])
240-
)
241-
):
242-
classes_repr = ", ".join(repr(c) for c in classes)
243-
raise ValueError(
244-
f"y_true takes value in {{{classes_repr}}} and pos_label is not "
245-
"specified: either make y_true take value in {0, 1} or "
246-
"{-1, 1} or pass pos_label explicitly."
247-
)
248-
elif pos_label is None:
249-
pos_label = 1
250-
251-
return pos_label

sklearn/metrics/_classification.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,11 @@
4040
from ..utils.extmath import _nanaverage
4141
from ..utils.multiclass import unique_labels
4242
from ..utils.multiclass import type_of_target
43-
from ..utils.validation import _num_samples
43+
from ..utils.validation import _check_pos_label_consistency, _num_samples
4444
from ..utils.sparsefuncs import count_nonzero
4545
from ..utils._param_validation import StrOptions, Options, Interval, validate_params
4646
from ..exceptions import UndefinedMetricWarning
4747

48-
from ._base import _check_pos_label_consistency
49-
5048

5149
def _check_zero_division(zero_division):
5250
if isinstance(zero_division, str) and zero_division == "warn":

sklearn/metrics/_plot/det_curve.py

Lines changed: 22 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
import scipy as sp
22

33
from .. import det_curve
4-
from .._base import _check_pos_label_consistency
4+
from ...utils._plotting import _BinaryClassifierCurveDisplayMixin
55

6-
from ...utils import check_matplotlib_support
7-
from ...utils._response import _get_response_values_binary
86

9-
10-
class DetCurveDisplay:
7+
class DetCurveDisplay(_BinaryClassifierCurveDisplayMixin):
118
"""DET curve visualization.
129
1310
It is recommended to use :func:`~sklearn.metrics.DetCurveDisplay.from_estimator`
@@ -163,15 +160,13 @@ def from_estimator(
163160
<...>
164161
>>> plt.show()
165162
"""
166-
check_matplotlib_support(f"{cls.__name__}.from_estimator")
167-
168-
name = estimator.__class__.__name__ if name is None else name
169-
170-
y_pred, pos_label = _get_response_values_binary(
163+
y_pred, pos_label, name = cls._validate_and_get_response_values(
171164
estimator,
172165
X,
173-
response_method,
166+
y,
167+
response_method=response_method,
174168
pos_label=pos_label,
169+
name=name,
175170
)
176171

177172
return cls.from_predictions(
@@ -259,22 +254,22 @@ def from_predictions(
259254
<...>
260255
>>> plt.show()
261256
"""
262-
check_matplotlib_support(f"{cls.__name__}.from_predictions")
257+
pos_label_validated, name = cls._validate_from_predictions_params(
258+
y_true, y_pred, sample_weight=sample_weight, pos_label=pos_label, name=name
259+
)
260+
263261
fpr, fnr, _ = det_curve(
264262
y_true,
265263
y_pred,
266264
pos_label=pos_label,
267265
sample_weight=sample_weight,
268266
)
269267

270-
pos_label = _check_pos_label_consistency(pos_label, y_true)
271-
name = "Classifier" if name is None else name
272-
273268
viz = DetCurveDisplay(
274269
fpr=fpr,
275270
fnr=fnr,
276271
estimator_name=name,
277-
pos_label=pos_label,
272+
pos_label=pos_label_validated,
278273
)
279274

280275
return viz.plot(ax=ax, name=name, **kwargs)
@@ -300,18 +295,12 @@ def plot(self, ax=None, *, name=None, **kwargs):
300295
display : :class:`~sklearn.metrics.plot.DetCurveDisplay`
301296
Object that stores computed values.
302297
"""
303-
check_matplotlib_support("DetCurveDisplay.plot")
298+
self.ax_, self.figure_, name = self._validate_plot_params(ax=ax, name=name)
304299

305-
name = self.estimator_name if name is None else name
306300
line_kwargs = {} if name is None else {"label": name}
307301
line_kwargs.update(**kwargs)
308302

309-
import matplotlib.pyplot as plt
310-
311-
if ax is None:
312-
_, ax = plt.subplots()
313-
314-
(self.line_,) = ax.plot(
303+
(self.line_,) = self.ax_.plot(
315304
sp.stats.norm.ppf(self.fpr),
316305
sp.stats.norm.ppf(self.fnr),
317306
**line_kwargs,
@@ -322,24 +311,22 @@ def plot(self, ax=None, *, name=None, **kwargs):
322311

323312
xlabel = "False Positive Rate" + info_pos_label
324313
ylabel = "False Negative Rate" + info_pos_label
325-
ax.set(xlabel=xlabel, ylabel=ylabel)
314+
self.ax_.set(xlabel=xlabel, ylabel=ylabel)
326315

327316
if "label" in line_kwargs:
328-
ax.legend(loc="lower right")
317+
self.ax_.legend(loc="lower right")
329318

330319
ticks = [0.001, 0.01, 0.05, 0.20, 0.5, 0.80, 0.95, 0.99, 0.999]
331320
tick_locations = sp.stats.norm.ppf(ticks)
332321
tick_labels = [
333322
"{:.0%}".format(s) if (100 * s).is_integer() else "{:.1%}".format(s)
334323
for s in ticks
335324
]
336-
ax.set_xticks(tick_locations)
337-
ax.set_xticklabels(tick_labels)
338-
ax.set_xlim(-3, 3)
339-
ax.set_yticks(tick_locations)
340-
ax.set_yticklabels(tick_labels)
341-
ax.set_ylim(-3, 3)
342-
343-
self.ax_ = ax
344-
self.figure_ = ax.figure
325+
self.ax_.set_xticks(tick_locations)
326+
self.ax_.set_xticklabels(tick_labels)
327+
self.ax_.set_xlim(-3, 3)
328+
self.ax_.set_yticks(tick_locations)
329+
self.ax_.set_yticklabels(tick_labels)
330+
self.ax_.set_ylim(-3, 3)
331+
345332
return self

0 commit comments

Comments
 (0)
0