@@ -703,9 +703,14 @@ def test_calibration_display_compute(pyplot, iris_data_binary, n_bins, strategy)
703
703
assert isinstance (viz .ax_ , mpl .axes .Axes )
704
704
assert isinstance (viz .figure_ , mpl .figure .Figure )
705
705
706
- assert viz .ax_ .get_xlabel () == "Mean predicted probability"
707
- assert viz .ax_ .get_ylabel () == "Fraction of positives"
708
- assert viz .line_ .get_label () == "LogisticRegression"
706
+ assert viz .ax_ .get_xlabel () == "Mean predicted probability (Positive class: 1)"
707
+ assert viz .ax_ .get_ylabel () == "Fraction of positives (Positive class: 1)"
708
+
709
+ expected_legend_labels = ["LogisticRegression" , "Perfectly calibrated" ]
710
+ legend_labels = viz .ax_ .get_legend ().get_texts ()
711
+ assert len (legend_labels ) == len (expected_legend_labels )
712
+ for labels in legend_labels :
713
+ assert labels .get_text () in expected_legend_labels
709
714
710
715
711
716
def test_plot_calibration_curve_pipeline (pyplot , iris_data_binary ):
@@ -714,8 +719,12 @@ def test_plot_calibration_curve_pipeline(pyplot, iris_data_binary):
714
719
clf = make_pipeline (StandardScaler (), LogisticRegression ())
715
720
clf .fit (X , y )
716
721
viz = CalibrationDisplay .from_estimator (clf , X , y )
717
- assert clf .__class__ .__name__ in viz .line_ .get_label ()
718
- assert viz .estimator_name == clf .__class__ .__name__
722
+
723
+ expected_legend_labels = [viz .estimator_name , "Perfectly calibrated" ]
724
+ legend_labels = viz .ax_ .get_legend ().get_texts ()
725
+ assert len (legend_labels ) == len (expected_legend_labels )
726
+ for labels in legend_labels :
727
+ assert labels .get_text () in expected_legend_labels
719
728
720
729
721
730
@pytest .mark .parametrize (
@@ -728,7 +737,13 @@ def test_calibration_display_default_labels(pyplot, name, expected_label):
728
737
729
738
viz = CalibrationDisplay (prob_true , prob_pred , y_prob , estimator_name = name )
730
739
viz .plot ()
731
- assert viz .line_ .get_label () == expected_label
740
+
741
+ expected_legend_labels = [] if name is None else [name ]
742
+ expected_legend_labels .append ("Perfectly calibrated" )
743
+ legend_labels = viz .ax_ .get_legend ().get_texts ()
744
+ assert len (legend_labels ) == len (expected_legend_labels )
745
+ for labels in legend_labels :
746
+ assert labels .get_text () in expected_legend_labels
732
747
733
748
734
749
def test_calibration_display_label_class_plot (pyplot ):
@@ -743,7 +758,12 @@ def test_calibration_display_label_class_plot(pyplot):
743
758
assert viz .estimator_name == name
744
759
name = "name two"
745
760
viz .plot (name = name )
746
- assert viz .line_ .get_label () == name
761
+
762
+ expected_legend_labels = [name , "Perfectly calibrated" ]
763
+ legend_labels = viz .ax_ .get_legend ().get_texts ()
764
+ assert len (legend_labels ) == len (expected_legend_labels )
765
+ for labels in legend_labels :
766
+ assert labels .get_text () in expected_legend_labels
747
767
748
768
749
769
@pytest .mark .parametrize ("constructor_name" , ["from_estimator" , "from_predictions" ])
@@ -766,11 +786,19 @@ def test_calibration_display_name_multiple_calls(
766
786
assert viz .estimator_name == clf_name
767
787
pyplot .close ("all" )
768
788
viz .plot ()
769
- assert clf_name == viz .line_ .get_label ()
789
+
790
+ expected_legend_labels = [clf_name , "Perfectly calibrated" ]
791
+ legend_labels = viz .ax_ .get_legend ().get_texts ()
792
+ assert len (legend_labels ) == len (expected_legend_labels )
793
+ for labels in legend_labels :
794
+ assert labels .get_text () in expected_legend_labels
795
+
770
796
pyplot .close ("all" )
771
797
clf_name = "another_name"
772
798
viz .plot (name = clf_name )
773
- assert clf_name == viz .line_ .get_label ()
799
+ assert len (legend_labels ) == len (expected_legend_labels )
800
+ for labels in legend_labels :
801
+ assert labels .get_text () in expected_legend_labels
774
802
775
803
776
804
def test_calibration_display_ref_line (pyplot , iris_data_binary ):
0 commit comments