8000 TST check the legend instead of label names in CalibrationDisplay (#2… · scikit-learn/scikit-learn@05e7064 · GitHub
[go: up one dir, main page]

Skip to content

Commit 05e7064

Browse files
committed
TST check the legend instead of label names in CalibrationDisplay (#21697)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 2e57aca commit 05e7064

File tree

2 files changed

+39
-11
lines changed

2 files changed

+39
-11
lines changed

sklearn/calibration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,8 +1094,8 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
10941094
ax.plot([0, 1], [0, 1], "k:", label=ref_line_label)
10951095
self.line_ = ax.plot(self.prob_pred, self.prob_true, "s-", **line_kwargs)[0]
10961096

1097-
if "label" in line_kwargs:
1098-
ax.legend(loc="lower right")
1097+
# We always have to show the legend for at least the reference line
1098+
ax.legend(loc="lower right")
10991099

11001100
ax.set(xlabel="Mean predicted probability", ylabel="Fraction of positives")
11011101

sklearn/tests/test_calibration.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -703,9 +703,14 @@ def test_calibration_display_compute(pyplot, iris_data_binary, n_bins, strategy)
703703
assert isinstance(viz.ax_, mpl.axes.Axes)
704704
assert isinstance(viz.figure_, mpl.figure.Figure)
705705

706-
assert viz.ax_.get_xlabel() == "Mean predicted probability"
707-
assert viz.ax_.get_ylabel() == "Fraction of positives"
708-
assert viz.line_.get_label() == "LogisticRegression"
706+
assert viz.ax_.get_xlabel() == "Mean predicted probability (Positive class: 1)"
707+
assert viz.ax_.get_ylabel() == "Fraction of positives (Positive class: 1)"
708+
709+
expected_legend_labels = ["LogisticRegression", "Perfectly calibrated"]
710+
legend_labels = viz.ax_.get_legend().get_texts()
711+
assert len(legend_labels) == len(expected_legend_labels)
712+
for labels in legend_labels:
713+
assert labels.get_text() in expected_legend_labels
709714

710715

711716
def test_plot_calibration_curve_pipeline(pyplot, iris_data_binary):
@@ -714,8 +719,12 @@ def test_plot_calibration_curve_pipeline(pyplot, iris_data_binary):
714719
clf = make_pipeline(StandardScaler(), LogisticRegression())
715720
clf.fit(X, y)
716721
viz = CalibrationDisplay.from_estimator(clf, X, y)
717-
assert clf.__class__.__name__ in viz.line_.get_label()
718-
assert viz.estimator_name == clf.__class__.__name__
722+
723+
expected_legend_labels = [viz.estimator_name, "Perfectly calibrated"]
724+
legend_labels = viz.ax_.get_legend().get_texts()
725+
assert len(legend_labels) == len(expected_legend_labels)
726+
for labels in legend_labels:
727+
assert labels.get_text() in expected_legend_labels
719728

720729

721730
@pytest.mark.parametrize(
@@ -728,7 +737,13 @@ def test_calibration_display_default_labels(pyplot, name, expected_label):
728737

729738
viz = CalibrationDisplay(prob_true, prob_pred, y_prob, estimator_name=name)
730739
viz.plot()
731-
assert viz.line_.get_label() == expected_label
740+
741+
expected_legend_labels = [] if name is None else [name]
742+
expected_legend_labels.append("Perfectly calibrated")
743+
legend_labels = viz.ax_.get_legend().get_texts()
744+
assert len(legend_labels) == len(expected_legend_labels)
745+
for labels in legend_labels:
746+
assert labels.get_text() in expected_legend_labels
732747

733748

734749
def test_calibration_display_label_class_plot(pyplot):
@@ -743,7 +758,12 @@ def test_calibration_display_label_class_plot(pyplot):
743758
assert viz.estimator_name == name
744759
name = "name two"
745760
viz.plot(name=name)
746-
assert viz.line_.get_label() == name
761+
762+
expected_legend_labels = [name, "Perfectly calibrated"]
763+
legend_labels = viz.ax_.get_legend().get_texts()
764+
assert len(legend_labels) == len(expected_legend_labels)
765+
for labels in legend_labels:
766+
assert labels.get_text() in expected_legend_labels
747767

748768

749769
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
@@ -766,11 +786,19 @@ def test_calibration_display_name_multiple_calls(
766786
assert viz.estimator_name == clf_name
767787
pyplot.close("all")
768788
viz.plot()
769-
assert clf_name == viz.line_.get_label()
789+
790+
expected_legend_labels = [clf_name, "Perfectly calibrated"]
791+
legend_labels = viz.ax_.get_legend().get_texts()
792+
assert len(legend_labels) == len(expected_legend_labels)
793+
for labels in legend_labels:
794+
assert labels.get_text() in expected_legend_labels
795+
770796
pyplot.close("all")
771797
clf_name = "another_name"
772798
viz.plot(name=clf_name)
773-
assert clf_name == viz.line_.get_label()
799+
assert len(legend_labels) == len(expected_legend_labels)
800+
for labels in legend_labels:
801+
assert labels.get_text() in expected_legend_labels
774802

775803

776804
def test_calibration_display_ref_line(pyplot, iris_data_binary):

0 commit comments

Comments
 (0)
0