From b2f75a5ff353028e82e4878b05d6e6a902a8fbc5 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 17 Nov 2021 14:57:58 +0100 Subject: [PATCH 1/2] Unify AxesSubplot string representation for labels across matplotlib versions This was changed in matplotlib 3.5.0 by: https://github.com/matplotlib/matplotlib/commit/e7e6c8cea7dd174751be335f89369aeb3c1f40f0#diff-501b7013d3efa42e08d1cc8dc7a27ee6944fcddb062cd7032249a0031bb01ff4R1641 --- sklearn/calibration.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sklearn/calibration.py b/sklearn/calibration.py index 6131a8f759d1a..a9a981c52879e 100644 --- a/sklearn/calibration.py +++ b/sklearn/calibration.py @@ -1099,6 +1099,9 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs): ax.plot([0, 1], [0, 1], "k:", label=ref_line_label) self.line_ = ax.plot(self.prob_pred, self.prob_true, "s-", **line_kwargs)[0] + # Unify AxesSubplot string representation for labels across matplotlib versions + self.line_.set_label(self.line_.get_label().replace("child", "line")) + if "label" in line_kwargs: ax.legend(loc="lower right") From de3b7422e015c43f5d11c4b02c5351bf5e1ca865 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 17 Nov 2021 17:37:19 +0100 Subject: [PATCH 2/2] Check for the integrity of the legend --- sklearn/calibration.py | 7 ++--- sklearn/tests/test_calibration.py | 49 ++++++++++++++++++++++++++----- 2 files changed, 43 insertions(+), 13 deletions(-) diff --git a/sklearn/calibration.py b/sklearn/calibration.py index a9a981c52879e..d5642033054ef 100644 --- a/sklearn/calibration.py +++ b/sklearn/calibration.py @@ -1099,11 +1099,8 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs): ax.plot([0, 1], [0, 1], "k:", label=ref_line_label) self.line_ = ax.plot(self.prob_pred, self.prob_true, "s-", **line_kwargs)[0] - # Unify AxesSubplot string representation for labels across matplotlib versions - self.line_.set_label(self.line_.get_label().replace("child", "line")) - - if "label" in line_kwargs: - ax.legend(loc="lower right") + # We always have to show the legend for at least the reference line + ax.legend(loc="lower right") xlabel = f"Mean predicted probability {info_pos_label}" ylabel = f"Fraction of positives {info_pos_label}" diff --git a/sklearn/tests/test_calibration.py b/sklearn/tests/test_calibration.py index ee7214bf224e8..b89f0e8d73cc8 100644 --- a/sklearn/tests/test_calibration.py +++ b/sklearn/tests/test_calibration.py @@ -689,7 +689,12 @@ def test_calibration_display_compute(pyplot, iris_data_binary, n_bins, strategy) assert viz.ax_.get_xlabel() == "Mean predicted probability (Positive class: 1)" assert viz.ax_.get_ylabel() == "Fraction of positives (Positive class: 1)" - assert viz.line_.get_label() == "LogisticRegression" + + expected_legend_labels = ["LogisticRegression", "Perfectly calibrated"] + legend_labels = viz.ax_.get_legend().get_texts() + assert len(legend_labels) == len(expected_legend_labels) + for labels in legend_labels: + assert labels.get_text() in expected_legend_labels def test_plot_calibration_curve_pipeline(pyplot, iris_data_binary): @@ -698,8 +703,12 @@ def test_plot_calibration_curve_pipeline(pyplot, iris_data_binary): clf = make_pipeline(StandardScaler(), LogisticRegression()) clf.fit(X, y) viz = CalibrationDisplay.from_estimator(clf, X, y) - assert clf.__class__.__name__ in viz.line_.get_label() - assert viz.estimator_name == clf.__class__.__name__ + + expected_legend_labels = [viz.estimator_name, "Perfectly calibrated"] + legend_labels = viz.ax_.get_legend().get_texts() + assert len(legend_labels) == len(expected_legend_labels) + for labels in legend_labels: + assert labels.get_text() in expected_legend_labels @pytest.mark.parametrize( @@ -712,7 +721,13 @@ def test_calibration_display_default_labels(pyplot, name, expected_label): viz = CalibrationDisplay(prob_true, prob_pred, y_prob, estimator_name=name) viz.plot() - assert viz.line_.get_label() == expected_label + + expected_legend_labels = [] if name is None else [name] + expected_legend_labels.append("Perfectly calibrated") + legend_labels = viz.ax_.get_legend().get_texts() + assert len(legend_labels) == len(expected_legend_labels) + for labels in legend_labels: + assert labels.get_text() in expected_legend_labels def test_calibration_display_label_class_plot(pyplot): @@ -727,7 +742,12 @@ def test_calibration_display_label_class_plot(pyplot): assert viz.estimator_name == name name = "name two" viz.plot(name=name) - assert viz.line_.get_label() == name + + expected_legend_labels = [name, "Perfectly calibrated"] + legend_labels = viz.ax_.get_legend().get_texts() + assert len(legend_labels) == len(expected_legend_labels) + for labels in legend_labels: + assert labels.get_text() in expected_legend_labels @pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) @@ -750,11 +770,19 @@ def test_calibration_display_name_multiple_calls( assert viz.estimator_name == clf_name pyplot.close("all") viz.plot() - assert clf_name == viz.line_.get_label() + + expected_legend_labels = [clf_name, "Perfectly calibrated"] + legend_labels = viz.ax_.get_legend().get_texts() + assert len(legend_labels) == len(expected_legend_labels) + for labels in legend_labels: + assert labels.get_text() in expected_legend_labels + pyplot.close("all") clf_name = "another_name" viz.plot(name=clf_name) - assert clf_name == viz.line_.get_label() + assert len(legend_labels) == len(expected_legend_labels) + for labels in legend_labels: + assert labels.get_text() in expected_legend_labels def test_calibration_display_ref_line(pyplot, iris_data_binary): @@ -832,7 +860,12 @@ def test_calibration_display_pos_label( viz.ax_.get_ylabel() == f"Fraction of positives (Positive class: {expected_pos_label})" ) - assert viz.line_.get_label() == "LogisticRegression" + + expected_legend_labels = [lr.__class__.__name__, "Perfectly calibrated"] + legend_labels = viz.ax_.get_legend().get_texts() + assert len(legend_labels) == len(expected_legend_labels) + for labels in legend_labels: + assert labels.get_text() in expected_legend_labels @pytest.mark.parametrize("method", ["sigmoid", "isotonic"])