8000 TST check the legend instead of label names in CalibrationDisplay by jjerphan · Pull Request #21697 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content
8000

TST check the legend instead of label names in CalibrationDisplay #21697

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sklearn/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,8 +1099,8 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
ax.plot([0, 1], [0, 1], "k:", label=ref_line_label)
self.line_ = ax.plot(self.prob_pred, self.prob_true, "s-", **line_kwargs)[0]

if "label" in line_kwargs:
ax.legend(loc="lower right")
# We always have to show the legend for at least the reference line
ax.legend(loc="lower right")

xlabel = f"Mean predicted probability {info_pos_label}"
ylabel = f"Fraction of positives {info_pos_label}"
Expand Down
49 changes: 41 additions & 8 deletions sklearn/tests/test_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,12 @@ def test_calibration_display_compute(pyplot, iris_data_binary, n_bins, strategy)

assert viz.ax_.get_xlabel() == "Mean predicted probability (Positive class: 1)"
assert viz.ax_.get_ylabel() == "Fraction of positives (Positive class: 1)"
assert viz.line_.get_label() == "LogisticRegression"

expected_legend_labels = ["LogisticRegression", "Perfectly calibrated"]
legend_labels = viz.ax_.get_legend().get_texts()
assert len(legend_labels) == len(expected_legend_labels)
for labels in legend_labels:
assert labels.get_text() in expected_legend_labels


def test_plot_calibration_curve_pipeline(pyplot, iris_data_binary):
Expand All @@ -698,8 +703,12 @@ def test_plot_calibration_curve_pipeline(pyplot, iris_data_binary):
clf = make_pipeline(StandardScaler(), LogisticRegression())
clf.fit(X, y)
viz = CalibrationDisplay.from_estimator(clf, X, y)
assert clf.__class__.__name__ in viz.line_.get_label()
assert viz.estimator_name == clf.__class__.__name__

expected_legend_labels = [viz.estimator_name, "Perfectly calibrated"]
legend_labels = viz.ax_.get_legend().get_texts()
assert len(legend_labels) == len(expected_legend_labels)
for labels in legend_labels:
assert labels.get_text() in expected_legend_labels


@pytest.mark.parametrize(
Expand All @@ -712,7 +721,13 @@ def test_calibration_display_default_labels(pyplot, name, expected_label):

viz = CalibrationDisplay(prob_true, prob_pred, y_prob, estimator_name=name)
viz.plot()
assert viz.line_.get_label() == expected_label

expected_legend_labels = [] if name is None else [name]
expected_legend_labels.append("Perfectly calibrated")
legend_labels = viz.ax_.get_legend().get_texts()
assert len(legend_labels) == len(expected_legend_labels)
for labels in legend_labels:
assert labels.get_text() in expected_legend_labels


def test_calibration_display_label_class_plot(pyplot):
Expand All @@ -727,7 +742,12 @@ def test_calibration_display_label_class_plot(pyplot):
assert viz.estimator_name == name
name = "name two"
viz.plot(name=name)
assert viz.line_.get_label() == name

expected_legend_labels = [name, "Perfectly calibrated"]
legend_labels = viz.ax_.get_legend().get_texts()
assert len(legend_labels) == len(expected_legend_labels)
for labels in legend_labels:
assert labels.get_text() in expected_legend_labels


@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
Expand All @@ -750,11 +770,19 @@ def test_calibration_display_name_multiple_calls(
assert viz.estimator_name == clf_name
pyplot.close("all")
viz.plot()
assert clf_name == viz.line_.get_label()

expected_legend_labels = [clf_name, "Perfectly calibrated"]
legend_labels = viz.ax_.get_legend().get_texts()
assert len(legend_labels) == len(expected_legend_labels)
for labels in legend_labels:
assert labels.get_text() in expected_legend_labels

pyplot.close("all")
clf_name = "another_name"
viz.plot(name=clf_name)
assert clf_name == viz.line_.get_label()
assert len(legend_labels) == len(expected_legend_labels)
for labels in legend_labels:
assert labels.get_text() in expected_legend_labels


def test_calibration_display_ref_line(pyplot, iris_data_binary):
Expand Down Expand Up @@ -832,7 +860,12 @@ def test_calibration_display_pos_label(
viz.ax_.get_ylabel()
== f"Fraction of positives (Positive class: {expected_pos_label})"
)
assert viz.line_.get_label() == "LogisticRegression"

expected_legend_labels = [lr.__class__.__name__, "Perfectly calibrated"]
legend_labels = viz.ax_.get_legend().get_texts()
assert len(legend_labels) == len(expected_legend_labels)
for labels in legend_labels:
assert labels.get_text() in expected_legend_labels


@pytest.mark.parametrize("method", ["sigmoid", "isotonic"])
Expand Down
0