From 21832afbb73886219f6c763d4b60289f81f8d44b Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Sun, 18 Jul 2021 00:13:12 +0200
Subject: [PATCH 01/25] ENH add from_estimator and from_predictions to
 PrecisionRecallDisplay

---
 .../metrics/_plot/precision_recall_curve.py | 206 +++++++++++++++++-
 1 file changed, 205 insertions(+), 1 deletion(-)

diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py
index 00937950a40e9..d71addd0465f4 100644
--- a/sklearn/metrics/_plot/precision_recall_curve.py
+++ b/sklearn/metrics/_plot/precision_recall_curve.py
@@ -1,9 +1,11 @@
+from sklearn.base import is_classifier
 from .base import _get_response

 from .. import average_precision_score
 from .. import precision_recall_curve
+from .._base import _check_pos_label_consistency

-from ...utils import check_matplotlib_support
+from ...utils import check_matplotlib_support, deprecated


 class PrecisionRecallDisplay:
@@ -144,7 +146,203 @@ def plot(self, ax=None, *, name=None, **kwargs):
         self.figure_ = ax.figure
         return self

+    @classmethod
+    def from_estimator(
+        cls,
+        estimator,
+        X,
+        y,
+        *,
+        sample_weight=None,
+        response_method="auto",
+        name=None,
+        pos_label=None,
+        ax=None,
+        **kwargs,
+    ):
+        """Plot precision-recall curve given an estimator and some data.
+
+        Parameters
+        ----------
+        estimator : estimator instance
+            Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
+            in which the last estimator is a classifier.
+
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            Input values.
+
+        y : array-like of shape (n_samples,)
+            Target values.
+
+        sample_weight : array-like of shape (n_samples,), default=None
+            Sample weights.
+
+        response_method : {'predict_proba', 'decision_function', 'auto'}, \
+                default='auto'
+            Specifies whether to use :term:`predict_proba` or
+            :term:`decision_function` as the target response. If set to 'auto',
+            :term:`predict_proba` is tried first and if it does not exist
+            :term:`decision_function` is tried next.
+
+        name : str, default=None
+            Name for labeling curve. If `None`, no name is used.
+
+        ax : matplotlib axes, default=None
+            Axes object to plot on. If `None`, a new figure and axes is created.
+
+        pos_label : str or int, default=None
+            The class considered as the positive class when computing the
+            precision and recall metrics. By default, `estimators.classes_[1]`
+            is considered as the positive class.
+
+        **kwargs : dict
+            Keyword arguments to be passed to matplotlib's `plot`.
+
+        Returns
+        -------
+        display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
+
+        See Also
+        --------
+        PrecisionRecallDisplay.from_predictions : Plot precision-recall curve
+            using estimated probabilities or output of decision function.
+
+        Examples
+        --------
+        >>> import matplotlib.pyplot as plt
+        >>> from sklearn.datasets import make_classification
+        >>> from sklearn.metrics import PrecisionRecallDisplay
+        >>> from sklearn.model_selection import train_test_split
+        >>> from sklearn.linear_model import LogisticRegression
+        >>> X, y = make_classification(random_state=0)
+        >>> X_train, X_test, y_train, y_test = train_test_split(
+        ...     X, y, random_state=0)
+        >>> clf = LogisticRegression()
+        >>> clf.fit(X_train, y_train)
+        LogisticRegression()
+        >>> PrecisionRecallDisplay.from_estimator(
+        ...     clf, X_test, y_test)
+        <...>
+        >>> plt.show()
+        """
+        method_name = f"{cls.__name__}.from_estimator"
+        check_matplotlib_support(method_name)
+        if not is_classifier(estimator):
+            raise ValueError(f"{method_name} requires a classifier")
+        y_pred, pos_label = _get_response(
+            X,
+            estimator,
+            response_method,
+            pos_label=pos_label,
+        )
+
+        name = name if name is not None else estimator.__class__.__name__
+
+        return cls.from_predictions(
+            y,
+            y_pred,
+            sample_weight=sample_weight,
+            name=name,
+            pos_label=pos_label,
+            ax=ax,
+            **kwargs,
+        )
+
+    @classmethod
+    def from_predictions(
+        cls,
+        y_true,
+        y_pred,
+        *,
+        sample_weight=None,
+        name=None,
+        pos_label=None,
+        ax=None,
+        **kwargs,
+    ):
+        """Plot precision-recall curve given binary class predictions.
+
+        Parameters
+        ----------
+        y_true : array-like of shape (n_samples,)
+            True binary labels.
+
+        y_pred : array-like of shape (n_samples,)
+            Estimated probabilities or output of decision function.
+
+        sample_weight : array-like of shape (n_samples,), default=None
+            Sample weights.
+
+        name : str, default=None
+            Name for labeling curve. If `None`, no name is used.
+
+        ax : matplotlib axes, default=None
+            Axes object to plot on. If `None`, a new figure and axes is created.
+
+        pos_label : str or int, default=None
+            The class considered as the positive class when computing the
+            precision and recall metrics. By default, `estimators.classes_[1]`
+            is considered as the positive class.
+
+        **kwargs : dict
+            Keyword arguments to be passed to matplotlib's `plot`.
+
+        Returns
+        -------
+        display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
+
+        See Also
+        --------
+        PrecisionRecallDisplay.from_estimator : Plot precision-recall curve
+            using an estimator.
+
+        Examples
+        --------
+        >>> import matplotlib.pyplot as plt
+        >>> from sklearn.datasets import make_classification
+        >>> from sklearn.metrics import PrecisionRecallDisplay
+        >>> from sklearn.model_selection import train_test_split
+        >>> from sklearn.linear_model import LogisticRegression
+        >>> X, y = make_classification(random_state=0)
+        >>> X_train, X_test, y_train, y_test = train_test_split(
+        ...     X, y, random_state=0)
+        >>> clf = LogisticRegression()
+        >>> clf.fit(X_train, y_train)
+        LogisticRegression()
+        >>> y_pred = clf.predict_proba(X_test)[:, 1]
+        >>> PrecisionRecallDisplay.from_predictions(
+        ...     y_test, y_pred)
+        <...>
+        >>> plt.show()
+        """
+        check_matplotlib_support(f"{cls.__name__}.from_predictions")
+
+        pos_label = _check_pos_label_consistency(pos_label, y_true)
+
+        precision, recall, _ = precision_recall_curve(
+            y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight
+        )
+        average_precision = average_precision_score(
+            y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight
+        )
+
+        viz = PrecisionRecallDisplay(
+            precision=precision,
+            recall=recall,
+            average_precision=average_precision,
+            estimator_name=name,
+            pos_label=pos_label,
+        )
+
+        return viz.plot(ax=ax, name=name, **kwargs)
+
+
+@deprecated(
+    "Function `plot_precision_recall_curve` is deprecated in 1.0 and will be "
+    "removed in 1.2. Use one of the class methods: "
+    "PrecisionRecallDisplay.from_predictions or "
+    "PrecisionRecallDisplay.from_estimator."
+)
 def plot_precision_recall_curve(
     estimator,
     X,
@@ -163,6 +361,12 @@ def plot_precision_recall_curve(

     Read more in the :ref:`User Guide `.

+    .. deprecated:: 1.0
+       `plot_precision_recall_curve` is deprecated in 1.0 and will be removed in
+       1.2. Use one of the following class methods:
+       :func:`~sklearn.metrics.PrecisionRecallDisplay.from_predictions` or
+       :func:`~sklearn.metrics.PrecisionRecallDisplay.from_estimator`.
+
     Parameters
     ----------
     estimator : estimator instance

From 4468c7a06e8cf649ab00c4cdde62bc2001af62ce Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Mon, 19 Jul 2021 09:29:45 +0200
Subject: [PATCH 02/25] iter

---
 setup.cfg                                                  | 2 +-
 sklearn/metrics/_plot/tests/test_plot_precision_recall.py | 3 +++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/setup.cfg b/setup.cfg
index 8ee90da7436c0..107a8abff8182 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -13,7 +13,7 @@ addopts =
     --ignore maint_tools
    --ignore asv_benchmarks
     --doctest-modules
-    --disable-pytest-warnings
+    # --disable-pytest-warnings
     -rxXs

 filterwarnings =
diff --git a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py
index 8ccc9125c4cf8..a4bd0267ac695 100644
--- a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py
+++ b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py
@@ -25,6 +25,9 @@
 )


+@pytest.mark.filterwarnings(
+    "ignore: Function `plot_precision_recall_curve` is deprecated"
+)
 def test_errors(pyplot):
     X, y_multiclass = make_classification(
         n_classes=3, n_samples=50, n_informative=3, random_state=0

From 7ea17685a13da5f823fe58018f6a38d9f673a7c9 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Mon, 19 Jul 2021 11:12:40 +0200
Subject: [PATCH 03/25] TST add some common test and deprecation check

---
 .../metrics/_plot/precision_recall_curve.py    |  8 ++++-
 .../tests/test_confusion_matrix_display.py     | 27 ---------------
 .../_plot/tests/test_display_common.py         | 34 +++++++++++++++++++
 .../_plot/tests/test_plot_precision_recall.py  |  9 +++--
 .../tests/test_precision_recall_display.py     | 24 +++++++++++++
 5 files changed, 69 insertions(+), 33 deletions(-)
 create mode 100644 sklearn/metrics/_plot/tests/test_display_common.py
 create mode 100644 sklearn/metrics/_plot/tests/test_precision_recall_display.py

diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py
index d71addd0465f4..f63a9bcfdc6bb 100644
--- a/sklearn/metrics/_plot/precision_recall_curve.py
+++ b/sklearn/metrics/_plot/precision_recall_curve.py
@@ -4,6 +4,7 @@
 from .. import average_precision_score
 from .. import precision_recall_curve
 from .._base import _check_pos_label_consistency
+from .._classification import check_consistent_length, unique_labels

 from ...utils import check_matplotlib_support, deprecated

@@ -228,7 +229,7 @@ def from_estimator(
         method_name = f"{cls.__name__}.from_estimator"
         check_matplotlib_support(method_name)
         if not is_classifier(estimator):
-            raise ValueError(f"{method_name} requires a classifier")
+            raise ValueError(f"{method_name} only supports classifiers")
         y_pred, pos_label = _get_response(
             X,
             estimator,
@@ -317,6 +318,11 @@ def from_predictions(
         """
         check_matplotlib_support(f"{cls.__name__}.from_predictions")

+        # for error consistency with other Displays used for binary
+        # classification, we need to make the following checks
+        check_consistent_length(y_true, y_pred, sample_weight)
+        # check for mixed types
+        unique_labels(y_true, y_pred)
         pos_label = _check_pos_label_consistency(pos_label, y_true)

         precision, recall, _ = precision_recall_curve(
diff --git a/sklearn/metrics/_plot/tests/test_confusion_matrix_display.py b/sklearn/metrics/_plot/tests/test_confusion_matrix_display.py
index 43d4171b42a05..5ff11e8d44c7d 100644
--- a/sklearn/metrics/_plot/tests/test_confusion_matrix_display.py
+++ b/sklearn/metrics/_plot/tests/test_confusion_matrix_display.py
@@ -12,7 +12,6 @@
 from sklearn.pipeline import make_pipeline
 from sklearn.preprocessing import StandardScaler
 from sklearn.svm import SVC
-from sklearn.svm import SVR

 from sklearn.metrics import ConfusionMatrixDisplay
 from sklearn.metrics import confusion_matrix
@@ -25,32 +24,6 @@
 )


-def test_confusion_matrix_display_validation(pyplot):
-    """Check that we raise the proper error when validating parameters."""
-    X, y = make_classification(
-        n_samples=100, n_informative=5, n_classes=5, random_state=0
-    )
-
-    regressor = SVR().fit(X, y)
-    y_pred_regressor = regressor.predict(X)
-    y_pred_classifier = SVC().fit(X, y).predict(X)
-
-    err_msg = "ConfusionMatrixDisplay.from_estimator only supports classifiers"
-    with pytest.raises(ValueError, match=err_msg):
-        ConfusionMatrixDisplay.from_estimator(regressor, X, y)
-
-    err_msg = "Mix type of y not allowed, got types"
-    with pytest.raises(ValueError, match=err_msg):
-        # Force `y_true` to be seen as a regression problem
-        ConfusionMatrixDisplay.from_predictions(y + 0.5, y_pred_classifier)
-    with pytest.raises(ValueError, match=err_msg):
-        ConfusionMatrixDisplay.from_predictions(y, y_pred_regressor)
-
-    err_msg = "Found input variables with inconsistent numbers of samples"
-    with pytest.raises(ValueError, match=err_msg):
-        ConfusionMatrixDisplay.from_predictions(y, y_pred_classifier[::2])
-
-
 @pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
 def test_confusion_matrix_display_invalid_option(pyplot, constructor_name):
     """Check the error raise if an invalid parameter value is passed."""
diff --git a/sklearn/metrics/_plot/tests/test_display_common.py b/sklearn/metrics/_plot/tests/test_display_common.py
new file mode 100644
index 0000000000000..21f7dfb0f2687
--- /dev/null
+++ b/sklearn/metrics/_plot/tests/test_display_common.py
@@ -0,0 +1,34 @@
+import pytest
+
+from sklearn.datasets import make_classification
+from sklearn.svm import SVC, SVR
+
+from sklearn.metrics import ConfusionMatrixDisplay, PrecisionRecallDisplay
+
+
+@pytest.mark.parametrize("Display", [ConfusionMatrixDisplay, PrecisionRecallDisplay])
+def test_confusion_matrix_display_validation(pyplot, Display):
+    """Check that we raise the proper error when validating parameters for
+    display handling binary classification."""
+    X, y = make_classification(
+        n_samples=100, n_informative=5, n_classes=5, random_state=0
+    )
+
+    regressor = SVR().fit(X, y)
+    y_pred_regressor = regressor.predict(X)
+    y_pred_classifier = SVC().fit(X, y).predict(X)
+
+    err_msg = f"{Display.__name__}.from_estimator only supports classifiers"
+    with pytest.raises(ValueError, match=err_msg):
+        Display.from_estimator(regressor, X, y)
+
+    err_msg = "Mix type of y not allowed, got types"
+    with pytest.raises(ValueError, match=err_msg):
+        # Force `y_true` to be seen as a regression problem
+        Display.from_predictions(y + 0.5, y_pred_classifier)
+    with pytest.raises(ValueError, match=err_msg):
+        Display.from_predictions(y, y_pred_regressor)
+
+    err_msg = "Found input variables with inconsistent numbers of samples"
+    with pytest.raises(ValueError, match=err_msg):
+        Display.from_predictions(y, y_pred_classifier[::2])
diff --git a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py
index a4bd0267ac695..3f2c44a84c5f3 100644
--- a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py
+++ b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py
@@ -18,16 +18,15 @@
 from sklearn.utils import shuffle
 from sklearn.compose import make_column_transformer

-# TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved
 pytestmark = pytest.mark.filterwarnings(
+    # TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved
     "ignore:In future, it will be an error for 'np.bool_':DeprecationWarning:"
-    "matplotlib.*"
+    "matplotlib.*",
+    # TODO: Remove in 1.2 (as well as all the tests below)
+    "ignore:Function plot_precision_recall_curve is deprecated",
 )


-@pytest.mark.filterwarnings(
-    "ignore: Function `plot_precision_recall_curve` is deprecated"
-)
 def test_errors(pyplot):
     X, y_multiclass = make_classification(
         n_classes=3, n_samples=50, n_informative=3, random_state=0
diff --git a/sklearn/metrics/_plot/tests/test_precision_recall_display.py b/sklearn/metrics/_plot/tests/test_precision_recall_display.py
new file mode 100644
index 0000000000000..b874349c30993
--- /dev/null
+++ b/sklearn/metrics/_plot/tests/test_precision_recall_display.py
@@ -0,0 +1,24 @@
+import pytest
+
+from sklearn.datasets import make_classification
+from sklearn.linear_model import LogisticRegression
+
+from sklearn.metrics import plot_precision_recall_curve
+
+# TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved
+pytestmark = pytest.mark.filterwarnings(
+    "ignore:In future, it will be an error for 'np.bool_':DeprecationWarning:"
+    "matplotlib.*"
+)
+
+
+# FIXME: Remove in 1.2
+def test_plot_precision_recall_curve_deprecation(pyplot):
+    """Check that we raise a FutureWarning when calling
+    `plot_precision_recall_curve`."""
+
+    X, y = make_classification(random_state=0)
+    clf = LogisticRegression().fit(X, y)
+    deprecation_warning = "Function plot_precision_recall_curve is deprecated"
+    with pytest.warns(FutureWarning, match=deprecation_warning):
+        plot_precision_recall_curve(clf, X, y)

From b57508febcc0a0c4edb168276791bb30ac22cdba Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Mon, 19 Jul 2021 14:24:28 +0200
Subject: [PATCH 04/25] iter

---
 .../metrics/_plot/precision_recall_curve.py    |   9 +-
 .../tests/test_confusion_matrix_display.py     |  31 ++++-
 .../_plot/tests/test_display_common.py         |  34 ------
 .../tests/test_precision_recall_display.py     | 107 +++++++++++++++++-
 4 files changed, 138 insertions(+), 43 deletions(-)
 delete mode 100644 sklearn/metrics/_plot/tests/test_display_common.py

diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py
index f63a9bcfdc6bb..1ae34dca4efd9 100644
--- a/sklearn/metrics/_plot/precision_recall_curve.py
+++ b/sklearn/metrics/_plot/precision_recall_curve.py
@@ -4,7 +4,7 @@
 from .. import average_precision_score
 from .. import precision_recall_curve
 from .._base import _check_pos_label_consistency
-from .._classification import check_consistent_length, unique_labels
+from .._classification import check_consistent_length

 from ...utils import check_matplotlib_support, deprecated

@@ -282,8 +282,7 @@ def from_predictions(

         pos_label : str or int, default=None
             The class considered as the positive class when computing the
-            precision and recall metrics. By default, `estimators.classes_[1]`
-            is considered as the positive class.
+            precision and recall metrics.

         **kwargs : dict
             Keyword arguments to be passed to matplotlib's `plot`.
@@ -318,11 +317,7 @@ def from_predictions(
         """
         check_matplotlib_support(f"{cls.__name__}.from_predictions")

-        # for error consistency with other Displays used for binary
-        # classification, we need to make the following checks
         check_consistent_length(y_true, y_pred, sample_weight)
-        # check for mixed types
-        unique_labels(y_true, y_pred)
         pos_label = _check_pos_label_consistency(pos_label, y_true)

         precision, recall, _ = precision_recall_curve(
diff --git a/sklearn/metrics/_plot/tests/test_confusion_matrix_display.py b/sklearn/metrics/_plot/tests/test_confusion_matrix_display.py
index 5ff11e8d44c7d..8db971fb26971 100644
--- a/sklearn/metrics/_plot/tests/test_confusion_matrix_display.py
+++ b/sklearn/metrics/_plot/tests/test_confusion_matrix_display.py
@@ -11,7 +11,7 @@
 from sklearn.linear_model import LogisticRegression
 from sklearn.pipeline import make_pipeline
 from sklearn.preprocessing import StandardScaler
-from sklearn.svm import SVC
+from sklearn.svm import SVC, SVR

 from sklearn.metrics import ConfusionMatrixDisplay
 from sklearn.metrics import confusion_matrix
@@ -24,6 +24,35 @@
 )


+def test_confusion_matrix_display_validation(pyplot):
+    """Check that we raise the proper error when validating parameters."""
+    X, y = make_classification(
+        n_samples=100, n_informative=5, n_classes=5, random_state=0
+    )
+
+    with pytest.raises(NotFittedError):
+        ConfusionMatrixDisplay.from_estimator(SVC(), X, y)
+
+    regressor = SVR().fit(X, y)
+    y_pred_regressor = regressor.predict(X)
+    y_pred_classifier = SVC().fit(X, y).predict(X)
+
+    err_msg = "ConfusionMatrixDisplay.from_estimator only supports classifiers"
+    with pytest.raises(ValueError, match=err_msg):
+        ConfusionMatrixDisplay.from_estimator(regressor, X, y)
+
+    err_msg = "Mix type of y not allowed, got types"
+    with pytest.raises(ValueError, match=err_msg):
+        # Force `y_true` to be seen as a regression problem
+        ConfusionMatrixDisplay.from_predictions(y + 0.5, y_pred_classifier)
+    with pytest.raises(ValueError, match=err_msg):
+        ConfusionMatrixDisplay.from_predictions(y, y_pred_regressor)
+
+    err_msg = "Found input variables with inconsistent numbers of samples"
+    with pytest.raises(ValueError, match=err_msg):
+        ConfusionMatrixDisplay.from_predictions(y, y_pred_classifier[::2])
+
+
 @pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
 def test_confusion_matrix_display_invalid_option(pyplot, constructor_name):
     """Check the error raise if an invalid parameter value is passed."""
diff --git a/sklearn/metrics/_plot/tests/test_display_common.py b/sklearn/metrics/_plot/tests/test_display_common.py
deleted file mode 100644
index 21f7dfb0f2687..0000000000000
--- a/sklearn/metrics/_plot/tests/test_display_common.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import pytest
-
-from sklearn.datasets import make_classification
-from sklearn.svm import SVC, SVR
-
-from sklearn.metrics import ConfusionMatrixDisplay, PrecisionRecallDisplay
-
-
-@pytest.mark.parametrize("Display", [ConfusionMatrixDisplay, PrecisionRecallDisplay])
-def test_confusion_matrix_display_validation(pyplot, Display):
-    """Check that we raise the proper error when validating parameters for
-    display handling binary classification."""
-    X, y = make_classification(
-        n_samples=100, n_informative=5, n_classes=5, random_state=0
-    )
-
-    regressor = SVR().fit(X, y)
-    y_pred_regressor = regressor.predict(X)
-    y_pred_classifier = SVC().fit(X, y).predict(X)
-
-    err_msg = f"{Display.__name__}.from_estimator only supports classifiers"
-    with pytest.raises(ValueError, match=err_msg):
-        Display.from_estimator(regressor, X, y)
-
-    err_msg = "Mix type of y not allowed, got types"
-    with pytest.raises(ValueError, match=err_msg):
-        # Force `y_true` to be seen as a regression problem
-        Display.from_predictions(y + 0.5, y_pred_classifier)
-    with pytest.raises(ValueError, match=err_msg):
-        Display.from_predictions(y, y_pred_regressor)
-
-    err_msg = "Found input variables with inconsistent numbers of samples"
-    with pytest.raises(ValueError, match=err_msg):
-        Display.from_predictions(y, y_pred_classifier[::2])
diff --git a/sklearn/metrics/_plot/tests/test_precision_recall_display.py b/sklearn/metrics/_plot/tests/test_precision_recall_display.py
index b874349c30993..b28c94db1dbc3 100644
--- a/sklearn/metrics/_plot/tests/test_precision_recall_display.py
+++ b/sklearn/metrics/_plot/tests/test_precision_recall_display.py
@@ -1,9 +1,12 @@
 import pytest

+from sklearn.base import BaseEstimator, ClassifierMixin
 from sklearn.datasets import make_classification
+from sklearn.exceptions import NotFittedError
 from sklearn.linear_model import LogisticRegression
+from sklearn.svm import SVC, SVR

-from sklearn.metrics import plot_precision_recall_curve
+from sklearn.metrics import PrecisionRecallDisplay, plot_precision_recall_curve

 # TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved
 pytestmark = pytest.mark.filterwarnings(
@@ -12,6 +15,48 @@
 )


+def test_confusion_matrix_display_validation(pyplot):
+    """Check that we raise the proper error when validating parameters."""
+    X, y = make_classification(
+        n_samples=100, n_informative=5, n_classes=5, random_state=0
+    )
+
+    with pytest.raises(NotFittedError):
+        PrecisionRecallDisplay.from_estimator(SVC(), X, y)
+
+    regressor = SVR().fit(X, y)
+    y_pred_regressor = regressor.predict(X)
+    classifier = SVC(probability=True).fit(X, y)
+    y_pred_classifier = classifier.predict_proba(X)[:, -1]
+
+    err_msg = "PrecisionRecallDisplay.from_estimator only supports classifiers"
+    with pytest.raises(ValueError, match=err_msg):
+        PrecisionRecallDisplay.from_estimator(regressor, X, y)
+
+    err_msg = "SVC should be a binary classifier"
+    with pytest.raises(ValueError, match=err_msg):
+        PrecisionRecallDisplay.from_estimator(classifier, X, y)
+
+    err_msg = "{} format is not supported"
+    with pytest.raises(ValueError, match=err_msg.format("continuous")):
+        # Force `y_true` to be seen as a regression problem
+        PrecisionRecallDisplay.from_predictions(y + 0.5, y_pred_classifier, pos_label=1)
+    with pytest.raises(ValueError, match=err_msg.format("multiclass")):
+        PrecisionRecallDisplay.from_predictions(y, y_pred_regressor, pos_label=1)
+
+    err_msg = "Found input variables with inconsistent numbers of samples"
+    with pytest.raises(ValueError, match=err_msg):
+        PrecisionRecallDisplay.from_predictions(y, y_pred_classifier[::2])
+
+    X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
+    y += 10
+    classifier.fit(X, y)
+    y_pred_classifier = classifier.predict_proba(X)[:, -1]
+    err_msg = r"y_true takes value in {10, 11} and pos_label is not specified"
+    with pytest.raises(ValueError, match=err_msg):
+        PrecisionRecallDisplay.from_predictions(y, y_pred_classifier)
+
+
 # FIXME: Remove in 1.2
 def test_plot_precision_recall_curve_deprecation(pyplot):
     """Check that we raise a FutureWarning when calling
@@ -22,3 +67,63 @@ def test_plot_precision_recall_curve_deprecation(pyplot):
     deprecation_warning = "Function plot_precision_recall_curve is deprecated"
     with pytest.warns(FutureWarning, match=deprecation_warning):
         plot_precision_recall_curve(clf, X, y)
+
+
+@pytest.mark.parametrize(
+    "response_method, msg",
+    [
+        (
+            "predict_proba",
+            "response method predict_proba is not defined in MyClassifier",
+        ),
+        (
+            "decision_function",
+            "response method decision_function is not defined in MyClassifier",
+        ),
+        (
+            "auto",
+            "response method decision_function or predict_proba is not "
+            "defined in MyClassifier",
+        ),
+        (
+            "bad_method",
+            "response_method must be 'predict_proba', 'decision_function' or 'auto'",
+        ),
+    ],
+)
+def test_precision_recall_display_bad_response(pyplot, response_method, msg):
+    """Check that the proper error is raised when passing a `response_method`
+    not compatible with the estimator."""
+    X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
+
+    class MyClassifier(ClassifierMixin, BaseEstimator):
+        def fit(self, X, y):
+            self.fitted_ = True
+            self.classes_ = [0, 1]
+            return self
+
+    clf = MyClassifier().fit(X, y)
+
+    with pytest.raises(ValueError, match=msg):
+        PrecisionRecallDisplay.from_estimator(
+            clf, X, y, response_method=response_method
+        )
+
+
+@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
+@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
+def test_precision_recall_display_plotting(pyplot, constructor_name, response_method):
+    """Check the overall plotting rendering."""
+    X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
+    classifier = LogisticRegression().fit(X, y)
+    classifier.fit(X, y)
+
+    y_pred = getattr(classifier, response_method)(X)
+
+    # safe guard for the binary if/else construction
+    assert constructor_name in ("from_estimator", "from_predictions")
+
+    if constructor_name == "from_estimator":
+        PrecisionRecallDisplay.from_estimator(classifier, X, y)
+    else:
+        PrecisionRecallDisplay.from_predictions(y, y_pred)

From bc242046086fc85b4a2a66976affe6d5de7f2524 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Mon, 19 Jul 2021 15:32:11 +0200
Subject: [PATCH 05/25] more tests

---
 .../tests/test_precision_recall_display.py | 75 ++++++++++++++++++-
 1 file changed, 73 insertions(+), 2 deletions(-)

diff --git a/sklearn/metrics/_plot/tests/test_precision_recall_display.py b/sklearn/metrics/_plot/tests/test_precision_recall_display.py
index b28c94db1dbc3..7703b9e9366c8 100644
--- a/sklearn/metrics/_plot/tests/test_precision_recall_display.py
+++ b/sklearn/metrics/_plot/tests/test_precision_recall_display.py
@@ -1,9 +1,11 @@
+import numpy as np
 import pytest

 from sklearn.base import BaseEstimator, ClassifierMixin
 from sklearn.datasets import make_classification
 from sklearn.exceptions import NotFittedError
 from sklearn.linear_model import LogisticRegression
+from sklearn.metrics import average_precision_score, precision_recall_curve
 from sklearn.svm import SVC, SVR

 from sklearn.metrics import PrecisionRecallDisplay, plot_precision_recall_curve
@@ -115,15 +117,84 @@ def fit(self, X, y):
 def test_precision_recall_display_plotting(pyplot, constructor_name, response_method):
     """Check the overall plotting rendering."""
     X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
+    pos_label = 1
+
     classifier = LogisticRegression().fit(X, y)
     classifier.fit(X, y)

     y_pred = getattr(classifier, response_method)(X)
+    y_pred = y_pred if y_pred.ndim == 1 else y_pred[:, pos_label]

     # safe guard for the binary if/else construction
     assert constructor_name in ("from_estimator", "from_predictions")

     if constructor_name == "from_estimator":
-        PrecisionRecallDisplay.from_estimator(classifier, X, y)
+        display = PrecisionRecallDisplay.from_estimator(
+            classifier, X, y, response_method=response_method
+        )
     else:
-        PrecisionRecallDisplay.from_predictions(y, y_pred)
+        display = PrecisionRecallDisplay.from_predictions(
+            y, y_pred, pos_label=pos_label
+        )
+
+    precision, recall, _ = precision_recall_curve(y, y_pred, pos_label=pos_label)
+    average_precision = average_precision_score(y, y_pred, pos_label=pos_label)
+
+    np.testing.assert_allclose(display.precision, precision)
+    np.testing.assert_allclose(display.recall, recall)
+    assert display.average_precision == pytest.approx(average_precision)
+
+    import matplotlib as mpl
+
+    assert isinstance(display.line_, mpl.lines.Line2D)
+    assert isinstance(display.ax_, mpl.axes.Axes)
+    assert isinstance(display.figure_, mpl.figure.Figure)
+
+    assert display.ax_.get_xlabel() == "Recall (Positive label: 1)"
+    assert display.ax_.get_ylabel() == "Precision (Positive label: 1)"
+
+    # plotting passing some new parameters
+    display.plot(alpha=0.8, name="MySpecialEstimator")
+    expected_label = f"MySpecialEstimator (AP = {average_precision:0.2f})"
+    assert display.line_.get_label() == expected_label
+    assert display.line_.get_alpha() == pytest.approx(0.8)
+
+
+@pytest.mark.parametrize(
+    "constructor_name, default_label",
+    [
+        ("from_estimator", "LogisticRegression (AP = {:.2f})"),
+        ("from_predictions", "AP = {:.2f}"),
+    ],
+)
+def test_precision_recall_display_name(pyplot, constructor_name, default_label):
+    """Check the behaviour of the name parameters"""
+    X, y = make_classification(n_classes=2, n_samples=100, random_state=0)
+    pos_label = 1
+
+    classifier = LogisticRegression().fit(X, y)
+    classifier.fit(X, y)
+
+    y_pred = classifier.predict_proba(X)[:, pos_label]
+
+    # safe guard for the binary if/else construction
+    assert constructor_name in ("from_estimator", "from_predictions")
+
+    if constructor_name == "from_estimator":
+        display = PrecisionRecallDisplay.from_estimator(classifier, X, y)
+    else:
+        display = PrecisionRecallDisplay.from_predictions(
+            y, y_pred, pos_label=pos_label
+        )
+
+    average_precision = average_precision_score(y, y_pred, pos_label=pos_label)
+
+    # check that the default name is used
+    assert display.line_.get_label() == default_label.format(average_precision)
+
+    # check that the name can be set
+    display.plot(name="MySpecialEstimator")
+    assert (
+        display.line_.get_label()
+        == f"MySpecialEstimator (AP = {average_precision:.2f})"
+    )

From a74c7b11c6365c3ab01e02b3c55d47621c2096a6 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Mon, 19 Jul 2021 15:35:24 +0200
Subject: [PATCH 06/25] more tests

---
 .../tests/test_precision_recall_display.py | 21 +++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/sklearn/metrics/_plot/tests/test_precision_recall_display.py b/sklearn/metrics/_plot/tests/test_precision_recall_display.py
index 7703b9e9366c8..23b17d45f67e5 100644
--- a/sklearn/metrics/_plot/tests/test_precision_recall_display.py
+++ b/sklearn/metrics/_plot/tests/test_precision_recall_display.py
@@ -2,10 +2,13 @@
 import pytest

 from sklearn.base import BaseEstimator, ClassifierMixin
+from sklearn.compose import make_column_transformer
 from sklearn.datasets import make_classification
 from sklearn.exceptions import NotFittedError
 from sklearn.linear_model import LogisticRegression
 from sklearn.metrics import average_precision_score, precision_recall_curve
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import StandardScaler
 from sklearn.svm import SVC, SVR

 from sklearn.metrics import PrecisionRecallDisplay, plot_precision_recall_curve
@@ -198,3 +201,21 @@ def test_precision_recall_display_name(pyplot, constructor_name, default_label):
         display.line_.get_label()
         == f"MySpecialEstimator (AP = {average_precision:.2f})"
     )
+
+
+@pytest.mark.parametrize(
+    "clf",
+    [
+        make_pipeline(StandardScaler(), LogisticRegression()),
+        make_pipeline(
+            make_column_transformer((StandardScaler(), [0, 1])), LogisticRegression()
+        ),
+    ],
+)
+def test_precision_recall_display_pipeline(pyplot, clf):
+    X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
+    with pytest.raises(NotFittedError):
+        PrecisionRecallDisplay.from_estimator(clf, X, y)
+    clf.fit(X, y)
+    display = PrecisionRecallDisplay.from_estimator(clf, X, y)
+    assert display.estimator_name == clf.__class__.__name__

From abfdb1a6c0ef58aa80d17d9d87e23df079cb205b Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Mon, 19 Jul 2021 16:01:51 +0200
Subject: [PATCH 07/25] TST check strings

---
 .../tests/test_precision_recall_display.py | 29 ++++++++++++++++++-
 1 file changed, 28 insertions(+), 1 deletion(-)

diff --git a/sklearn/metrics/_plot/tests/test_precision_recall_display.py b/sklearn/metrics/_plot/tests/test_precision_recall_display.py
index 23b17d45f67e5..d633cb4abba50 100644
--- a/sklearn/metrics/_plot/tests/test_precision_recall_display.py
+++ b/sklearn/metrics/_plot/tests/test_precision_recall_display.py
@@ -3,7 +3,7 @@

 from sklearn.base import BaseEstimator, ClassifierMixin
 from sklearn.compose import make_column_transformer
-from sklearn.datasets import make_classification
+from sklearn.datasets import load_breast_cancer, make_classification
 from sklearn.exceptions import NotFittedError
 from sklearn.linear_model import LogisticRegression
 from sklearn.metrics import average_precision_score, precision_recall_curve
@@ -219,3 +219,30 @@ def test_precision_recall_display_pipeline(pyplot, clf):
     clf.fit(X, y)
     display = PrecisionRecallDisplay.from_estimator(clf, X, y)
     assert display.estimator_name == clf.__class__.__name__
+
+
+def test_precision_recall_display_string_labels(pyplot):
+    # regression test #15738
+    cancer = load_breast_cancer()
+    X, y = cancer.data, cancer.target_names[cancer.target]
+
+    lr = make_pipeline(StandardScaler(), LogisticRegression())
+    lr.fit(X, y)
+    for klass in cancer.target_names:
+        assert klass in lr.classes_
+    display = PrecisionRecallDisplay.from_estimator(lr, X, y)
+
+    y_pred = lr.predict_proba(X)[:, 1]
+    avg_prec = average_precision_score(y, y_pred, pos_label=lr.classes_[1])
+
+    assert display.average_precision == pytest.approx(avg_prec)
+    assert display.estimator_name == lr.__class__.__name__
+
+    err_msg = r"y_true takes value in {'benign', 'malignant'}"
+    with pytest.raises(ValueError, match=err_msg):
+        PrecisionRecallDisplay.from_predictions(y, y_pred)
+
+    display = PrecisionRecallDisplay.from_predictions(
+        y, y_pred, pos_label=lr.classes_[1]
+    )
+    assert display.average_precision == pytest.approx(avg_prec)

From 1c154ff2b06e698045caab425a158ee715ef2335 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Mon, 19 Jul 2021 16:15:50 +0200
Subject: [PATCH 08/25] TST add test to check average precision computed

---
 .../_plot/tests/test_plot_precision_recall.py | 19 ----
 .../tests/test_precision_recall_display.py    | 96 +++++++++++++++++++
 2 files changed, 96 insertions(+), 19 deletions(-)

diff --git a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py
index 3f2c44a84c5f3..81d18bb2ddfff 100644
--- a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py
+++ b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py
@@ -4,7 +4,6 @@

 from sklearn.base import BaseEstimator, ClassifierMixin
 from sklearn.metrics import plot_precision_recall_curve
-from sklearn.metrics import PrecisionRecallDisplay
 from sklearn.metrics import average_precision_score
 from sklearn.metrics import precision_recall_curve
 from sklearn.datasets import make_classification
@@ -197,24 +196,6 @@ def test_plot_precision_recall_curve_estimator_name_multiple_calls(pyplot):
     assert clf_name in disp.line_.get_label()


-@pytest.mark.parametrize(
-    "average_precision, estimator_name, expected_label",
-    [
-        (0.9, None, "AP = 0.90"),
-        (None, "my_est", "my_est"),
-        (0.8, "my_est2", "my_est2 (AP = 0.80)"),
-    ],
-)
-def test_default_labels(pyplot, average_precision, estimator_name, expected_label):
-    prec = np.array([1, 0.5, 0])
-    recall = np.array([0, 0.5, 1])
-    disp = PrecisionRecallDisplay(
-        prec, recall, average_precision=average_precision, estimator_name=estimator_name
-    )
-    disp.plot()
-    assert disp.line_.get_label() == expected_label
-
-
 @pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
 def test_plot_precision_recall_pos_label(pyplot, response_method):
     # check that we can provide the positive label and display the proper
diff --git a/sklearn/metrics/_plot/tests/test_precision_recall_display.py b/sklearn/metrics/_plot/tests/test_precision_recall_display.py
index d633cb4abba50..db4be93acc616 100644
--- a/sklearn/metrics/_plot/tests/test_precision_recall_display.py
+++ b/sklearn/metrics/_plot/tests/test_precision_recall_display.py
@@ -7,9 +7,11 @@
 from sklearn.exceptions import NotFittedError
 from sklearn.linear_model import LogisticRegression
 from sklearn.metrics import average_precision_score, precision_recall_curve
+from sklearn.model_selection import train_test_split
 from sklearn.pipeline import make_pipeline
 from sklearn.preprocessing import StandardScaler
 from sklearn.svm import SVC, SVR
+from sklearn.utils import shuffle

 from sklearn.metrics import PrecisionRecallDisplay, plot_precision_recall_curve

@@ -246,3 +248,97 @@ def test_precision_recall_display_string_labels(pyplot):
         y, y_pred, pos_label=lr.classes_[1]
     )
     assert display.average_precision == pytest.approx(avg_prec)
+
+
+@pytest.mark.parametrize(
+    "average_precision, estimator_name, expected_label",
+    [
+        (0.9, None, "AP = 0.90"),
+        (None, "my_est", "my_est"),
+        (0.8, "my_est2", "my_est2 (AP = 0.80)"),
+    ],
+)
+def test_default_labels(pyplot, average_precision, estimator_name, expected_label):
+    """Check the default labels used in the display."""
+    precision = np.array([1, 0.5, 0])
+    recall = np.array([0, 0.5, 1])
+    display = PrecisionRecallDisplay(
+        precision,
+        recall,
+        average_precision=average_precision,
+        estimator_name=estimator_name,
+    )
+    display.plot()
+    assert display.line_.get_label() == expected_label
+
+
+@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
+@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
+def test_plot_precision_recall_pos_label(pyplot, constructor_name, response_method):
+    # check that we can provide the positive label and display the proper
+    # statistics
+    X, y = load_breast_cancer(return_X_y=True)
+    # create an highly imbalanced version of the breast cancer dataset
+    idx_positive = np.flatnonzero(y == 1)
+    idx_negative = np.flatnonzero(y == 0)
+    idx_selected = np.hstack([idx_negative, idx_positive[:25]])
+    X, y = X[idx_selected], y[idx_selected]
+    X, y = shuffle(X, y, random_state=42)
+    # only use 2 features to make the problem even harder
+    X = X[:, :2]
+    y = np.array(["cancer" if c == 1 else "not cancer" for c in y], dtype=object)
+    X_train, X_test, y_train, y_test = train_test_split(
+        X,
+        y,
+        stratify=y,
+        random_state=0,
+    )
+
+    classifier = LogisticRegression()
+    classifier.fit(X_train, y_train)
+
+    # sanity check to be sure the positive class is classes_[0] and that we
+    # are betrayed by the class imbalance
+    assert classifier.classes_.tolist() == ["cancer", "not cancer"]
+
+    y_pred = getattr(classifier, response_method)(X_test)
+    y_pred_cancer = y_pred if y_pred.ndim == 1 else y_pred[:, 0]
+    y_pred_not_cancer = y_pred if y_pred.ndim == 1 else y_pred[:, 1]
+
+    if constructor_name == "from_estimator":
+        display = PrecisionRecallDisplay.from_estimator(
+            classifier,
+            X_test,
+            y_test,
+            pos_label="cancer",
+            response_method=response_method,
+        )
+    else:
+        display = PrecisionRecallDisplay.from_predictions(
+            y_test,
+            y_pred_cancer,
+            pos_label="cancer",
+        )
+    # we should obtain the statistics of the "cancer" class
+    avg_prec_limit = 0.65
+    assert display.average_precision < avg_prec_limit
+    assert -np.trapz(display.precision, display.recall) < avg_prec_limit
+
+    # otherwise we should obtain the statistics of the "not cancer" class
+    if constructor_name == "from_estimator":
+        display = PrecisionRecallDisplay.from_estimator(
+            classifier,
+            X_test,
+            y_test,
+            response_method=response_method,
+            pos_label="not cancer",
+        )
+    else:
+        display = PrecisionRecallDisplay.from_predictions(
+            y_test,
+            y_pred_not_cancer,
+            pos_label="not cancer",
+        )
+    avg_prec_limit = 0.95
+    assert display.average_precision > avg_prec_limit
+    assert -np.trapz(display.precision, display.recall) > avg_prec_limit

From 73a6ab063dc90cc53298e09792cb225a775880b8 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Mon, 19 Jul 2021 16:22:07 +0200
Subject: [PATCH 09/25] iter

---
 doc/whats_new/v1.0.rst | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst
index b0079e9fe527e..b8197e7f02de7 100644
--- a/doc/whats_new/v1.0.rst
+++ b/doc/whats_new/v1.0.rst
@@ -447,6 +447,14 @@ Changelog
   class methods and will be removed in 1.2.
   :pr:`18543` by `Guillaume Lemaitre`_.

+
+ - |API| :class:`metrics.PrecisionRecallDisplay` exposes two class methods
+  :func:`~metrics.PrecisionRecallDisplay.from_estimator` and
+  :func:`~metrics.PrecisionRecallDisplay.from_predictions` allowing to create
+  a precision-recall curve using an estimator or the predictions.
+  :func:`metrics.plot_precision_recall_curve` is deprecated in favor of these
+  two class methods and will be removed in 1.2.
+  :pr:`20552` by `Guillaume Lemaitre`_.
+
 - |Enhancement| A fix to raise an error in :func:`metrics.hinge_loss` when
   ``pred_decision`` is 1d whereas it is a multiclass classification or when
   ``pred_decision`` parameter is not consistent with the ``labels`` parameter.

From 1529e6645cbf36b17e056097361e5927d9e3c75c Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Mon, 19 Jul 2021 16:29:48 +0200
Subject: [PATCH 10/25] DOC some update

---
 sklearn/metrics/_plot/precision_recall_curve.py | 17 +++++++++++------
 sklearn/metrics/_ranking.py                     |  7 ++++---
 2 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py
index 1ae34dca4efd9..7ac9954005cdb 100644
--- a/sklearn/metrics/_plot/precision_recall_curve.py
+++ b/sklearn/metrics/_plot/precision_recall_curve.py
@@ -12,17 +12,20 @@
 class PrecisionRecallDisplay:
     """Precision Recall visualization.

-    It is recommend to use :func:`~sklearn.metrics.plot_precision_recall_curve`
-    to create a visualizer. All parameters are stored as attributes.
+    It is recommended to use
+    :func:`~sklearn.metrics.PrecisionRecallDisplay.from_estimator` or
+    :func:`~sklearn.metrics.PrecisionRecallDisplay.from_predictions` to create
+    a :class:`~sklearn.metrics.PrecisionRecallDisplay`. All parameters are
+    stored as attributes.

     Read more in the :ref:`User Guide `.

     Parameters
     -----------
-    precision : ndarray
+    precision : ndarray of shape (n_samples,)
         Precision values.

-    recall : ndarray
+    recall : ndarray of shape (n_samples,)
         Recall values.

     average_precision : float, default=None
@@ -52,8 +55,10 @@ class PrecisionRecallDisplay:
     --------
     precision_recall_curve : Compute precision-recall pairs for different
         probability thresholds.
-    plot_precision_recall_curve : Plot Precision Recall Curve for binary
-        classifiers.
+    PrecisionRecallDisplay.from_estimator : Plot Precision Recall Curve given
+        a binary classifier.
+    PrecisionRecallDisplay.from_predictions : Plot Precision Recall Curve
+        using predictions from a binary classifier.

     Examples
     --------
diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py
index 603d7c4d5be56..081a6ee9fa477 100644
--- a/sklearn/metrics/_ranking.py
+++ b/sklearn/metrics/_ranking.py
@@ -823,9 +823,10 @@ def precision_recall_curve(y_true, probas_pred, *, pos_label=None, sample_weight

     See Also
     --------
-    plot_precision_recall_curve : Plot Precision Recall Curve for binary
-        classifiers.
-    PrecisionRecallDisplay : Precision Recall visualization.
+    PrecisionRecallDisplay.from_estimator : Plot Precision Recall Curve given
+        a binary classifier.
+    PrecisionRecallDisplay.from_predictions : Plot Precision Recall Curve
+        using predictions from a binary classifier.
     average_precision_score : Compute average precision from prediction scores.
     det_curve: Compute error rates for different probability thresholds.
     roc_curve : Compute Receiver operating characteristic (ROC) curve.
From 6679579522861c316e40182cbf0501ce1c7f7a2a Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Mon, 19 Jul 2021 17:00:55 +0200
Subject: [PATCH 11/25] DOC some update

---
 .../model_selection/plot_precision_recall.py | 115 +++++++++---------
 1 file changed, 55 insertions(+), 60 deletions(-)

diff --git a/examples/model_selection/plot_precision_recall.py b/examples/model_selection/plot_precision_recall.py
index 83493c44c7847..c95c9dbb254a9 100644
--- a/examples/model_selection/plot_precision_recall.py
+++ b/examples/model_selection/plot_precision_recall.py
@@ -112,9 +112,9 @@
 X = np.c_[X, random_state.randn(n_samples, 200 * n_features)]

 # Limit to the two first classes, and split into training and test
-X_train, X_test, y_train, y_test = train_test_split(X[y < 2], y[y < 2],
-                                                    test_size=.5,
-                                                    random_state=random_state)
+X_train, X_test, y_train, y_test = train_test_split(
+    X[y < 2], y[y < 2], test_size=0.5, random_state=random_state
+)

 # Create a simple classifier
 classifier = svm.LinearSVC(random_state=random_state)
@@ -125,21 +125,18 @@
 # Compute the average precision score
 # ...................................
 from sklearn.metrics import average_precision_score
+
 average_precision = average_precision_score(y_test, y_score)

-print('Average precision-recall score: {0:0.2f}'.format(
-    average_precision))
+print("Average precision-recall score: {0:0.2f}".format(average_precision))

 # %%
 # Plot the Precision-Recall curve
 # ................................
-from sklearn.metrics import precision_recall_curve
-from sklearn.metrics import plot_precision_recall_curve
-import matplotlib.pyplot as plt
+from sklearn.metrics import PrecisionRecallDisplay

-disp = plot_precision_recall_curve(classifier, X_test, y_test)
-disp.ax_.set_title('2-class Precision-Recall curve: '
-                   'AP={0:0.2f}'.format(average_precision))
+display = PrecisionRecallDisplay.from_estimator(classifier, X_test, y_test)
+display.ax_.set_title("2-class Precision-Recall curve")

 # %%
 # In multi-label settings
 # ------------------------
@@ -158,8 +155,9 @@
 n_classes = Y.shape[1]

 # Split into training and test
-X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.5,
-                                                    random_state=random_state)
+X_train, X_test, Y_train, Y_test = train_test_split(
+    X, Y, test_size=0.5, random_state=random_state
+)

 # We use OneVsRestClassifier for multi-label prediction
 from sklearn.multiclass import OneVsRestClassifier

 # Run classifier
@@ -173,6 +171,7 @@
 # %%
 # The average precision score in multi-label settings
 # ....................................................
+from sklearn.metrics import precision_recall_curve
 from sklearn.metrics import average_precision_score

 # For each class
 precision = dict()
 recall = dict()
 average_precision = dict()
 for i in range(n_classes):
-    precision[i], recall[i], _ = precision_recall_curve(Y_test[:, i],
-                                                        y_score[:, i])
+    precision[i], recall[i], _ = precision_recall_curve(Y_test[:, i], y_score[:, i])
     average_precision[i] = average_precision_score(Y_test[:, i], y_score[:, i])

 # A "micro-average": quantifying score on all classes jointly
-precision["micro"], recall["micro"], _ = precision_recall_curve(Y_test.ravel(),
-                                                                y_score.ravel())
-average_precision["micro"] = average_precision_score(Y_test, y_score,
-                                                     average="micro")
-print('Average precision score, micro-averaged over all classes: {0:0.2f}'
-      .format(average_precision["micro"]))
+precision["micro"], recall["micro"], _ = precision_recall_curve(
+    Y_test.ravel(), y_score.ravel()
+)
+average_precision["micro"] = average_precision_score(Y_test, y_score, average="micro")

 # %%
 # Plot the micro-averaged Precision-Recall curve
 # ...............................................
-#
-
-plt.figure()
-plt.step(recall['micro'], precision['micro'], where='post')
-
-plt.xlabel('Recall')
-plt.ylabel('Precision')
-plt.ylim([0.0, 1.05])
-plt.xlim([0.0, 1.0])
-plt.title(
-    'Average precision score, micro-averaged over all classes: AP={0:0.2f}'
-    .format(average_precision["micro"]))
+display = PrecisionRecallDisplay(
+    recall=recall["micro"],
+    precision=precision["micro"],
+    average_precision=average_precision["micro"],
+)
+display.plot()
+_ = display.ax_.set_title("Micro-averaged over all classes")

 # %%
 # Plot Precision-Recall curve for each class and iso-f1 curves
 # .............................................................
-#
+import matplotlib.pyplot as plt
 from itertools import cycle

 # setup plot details
-colors = cycle(['navy', 'turquoise', 'darkorange', 'cornflowerblue', 'teal'])
+colors = cycle(["navy", "turquoise", "darkorange", "cornflowerblue", "teal"])
+
+_, ax = plt.subplots(figsize=(7, 8))

-plt.figure(figsize=(7, 8))
 f_scores = np.linspace(0.2, 0.8, num=4)
-lines = []
-labels = []
+lines, labels = [], []
 for f_score in f_scores:
     x = np.linspace(0.01, 1)
     y = f_score * x / (2 * x - f_score)
-    l, = plt.plot(x[y >= 0], y[y >= 0], color='gray', alpha=0.2)
-    plt.annotate('f1={0:0.1f}'.format(f_score), xy=(0.9, y[45] + 0.02))
+    (l,) = plt.plot(x[y >= 0], y[y >= 0], color="gray", alpha=0.2)
+    plt.annotate("f1={0:0.1f}".format(f_score), xy=(0.9, y[45] + 0.02))

-lines.append(l)
-labels.append('iso-f1 curves')
-l, = plt.plot(recall["micro"], precision["micro"], color='gold', lw=2)
-lines.append(l)
-labels.append('micro-average Precision-recall (area = {0:0.2f})'
-              ''.format(average_precision["micro"]))
+display = PrecisionRecallDisplay(
+    recall=recall["micro"],
+    precision=precision["micro"],
+    average_precision=average_precision["micro"],
+)
+display.plot(ax=ax, name="Micro-average precision-recall", color="gold")

 for i, color in zip(range(n_classes), colors):
-    l, = plt.plot(recall[i], precision[i], color=color, lw=2)
-    lines.append(l)
-    labels.append('Precision-recall for class {0} (area = {1:0.2f})'
-                  ''.format(i, average_precision[i]))
-
-fig = plt.gcf()
-fig.subplots_adjust(bottom=0.25)
-plt.xlim([0.0, 1.0])
-plt.ylim([0.0, 1.05])
-plt.xlabel('Recall')
-plt.ylabel('Precision')
-plt.title('Extension of Precision-Recall curve to multi-class')
-plt.legend(lines, labels, loc=(0, -.38), prop=dict(size=14))
-
+    display = PrecisionRecallDisplay(
+        recall=recall[i],
+        precision=precision[i],
+        average_precision=average_precision[i],
+    )
+    display.plot(ax=ax, name=f"Precision-recall for class {i}", color=color)
+
+handles, labels = display.ax_.get_legend_handles_labels()
+handles.extend([l])
+labels.extend(["iso-f1 curves"])
+
+ax.set_xlim([0.0, 1.0])
+ax.set_ylim([0.0, 1.05])
+ax.legend(handles=handles, labels=labels, loc="best")
+ax.set_title("Extension of Precision-Recall curve to multi-class")
 plt.show()
+
+# %%

From 9aa4af8401ee366ea4e917d546c61f5635d5e5be Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Mon, 19 Jul 2021 17:04:58 +0200
Subject: [PATCH 12/25] iter

---
 doc/whats_new/v1.0.rst                             |  2 +-
 examples/model_selection/plot_precision_recall.py  | 12 ++----------
 2 files changed, 3 insertions(+), 11 deletions(-)

diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst
index b8197e7f02de7..29d483082db54 100644
--- a/doc/whats_new/v1.0.rst
+++ b/doc/whats_new/v1.0.rst
@@ -447,7 +447,7 @@ Changelog
   class methods and will be removed in 1.2.
   :pr:`18543` by `Guillaume Lemaitre`_.

-  - |API| :class:`metrics.PrecisionRecallDisplay` exposes two class methods
+- |API| :class:`metrics.PrecisionRecallDisplay` exposes two class methods
   :func:`~metrics.PrecisionRecallDisplay.from_estimator` and
   :func:`~metrics.PrecisionRecallDisplay.from_predictions` allowing to create
   a precision-recall curve using an estimator or the predictions.
   :func:`metrics.plot_precision_recall_curve` is deprecated in favor of these
   two class methods and will be removed in 1.2.
   :pr:`20552` by `Guillaume Lemaitre`_.
diff --git a/examples/model_selection/plot_precision_recall.py b/examples/model_selection/plot_precision_recall.py
index c95c9dbb254a9..47a82da442e41 100644
--- a/examples/model_selection/plot_precision_recall.py
+++ b/examples/model_selection/plot_precision_recall.py
@@ -121,15 +121,6 @@
 classifier.fit(X_train, y_train)
 y_score = classifier.decision_function(X_test)

-# %%
-# Compute the average precision score
-# ...................................
-from sklearn.metrics import average_precision_score
-
-average_precision = average_precision_score(y_test, y_score)
-
-print("Average precision-recall score: {0:0.2f}".format(average_precision))
-
 # %%
 # Plot the Precision-Recall curve
 # ................................
@@ -233,10 +224,11 @@
     )
     display.plot(ax=ax, name=f"Precision-recall for class {i}", color=color)

+# add the legend for the iso-f1 curves
 handles, labels = display.ax_.get_legend_handles_labels()
 handles.extend([l])
 labels.extend(["iso-f1 curves"])
-
+# set the legend and the axes
 ax.set_xlim([0.0, 1.0])
 ax.set_ylim([0.0, 1.05])
 ax.legend(handles=handles, labels=labels, loc="best")

From 2c9f6b1ecd8c316e09bad2ea213c2f4ef7b62e74 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Mon, 19 Jul 2021 17:22:17 +0200
Subject: [PATCH 13/25] iter

---
 .../model_selection/plot_precision_recall.py | 74 +++++++++++++------
 1 file changed, 52 insertions(+), 22 deletions(-)

diff --git a/examples/model_selection/plot_precision_recall.py b/examples/model_selection/plot_precision_recall.py
index 47a82da442e41..ff36271228377 100644
--- a/examples/model_selection/plot_precision_recall.py
+++ b/examples/model_selection/plot_precision_recall.py
@@ -92,52 +92,81 @@
 """
 # %%
 # In binary classification settings
-# --------------------------------------------------------
+# ---------------------------------
 #
-# Create simple data
-# ..................
+# Dataset and model
+# .................
 #
-# Try to differentiate the two first classes of the iris data
-from sklearn import svm, datasets
-from sklearn.model_selection import train_test_split
+# We will try to create a linear model to differentiate two types of irises.
+# We will use a Linear SVC classifier.
 import numpy as np

+from sklearn.datasets import load_iris
+from sklearn.model_selection import train_test_split

-iris = datasets.load_iris()
-X = iris.data
-y = iris.target
+X, y = load_iris(return_X_y=True)

 # Add noisy features
 random_state = np.random.RandomState(0)
 n_samples, n_features = X.shape
-X = np.c_[X, random_state.randn(n_samples, 200 * n_features)]
+X = np.concatenate([X, random_state.randn(n_samples, 200 * n_features)], axis=1)

 # Limit to the two first classes, and split into training and test
 X_train, X_test, y_train, y_test = train_test_split(
     X[y < 2], y[y < 2], test_size=0.5, random_state=random_state
 )

-# Create a simple classifier
-classifier = svm.LinearSVC(random_state=random_state)
+# %%
+# Linear SVC will expect each feature to have a similar range of values. Thus,
+# we will first scale the data using a
+# :class:`~sklearn.preprocessing.StandardScaler`.
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import StandardScaler
+from sklearn.svm import LinearSVC
+
+classifier = make_pipeline(StandardScaler(), LinearSVC(random_state=random_state))
 classifier.fit(X_train, y_train)
-y_score = classifier.decision_function(X_test)

 # %%
 # Plot the Precision-Recall curve
-# ................................
+# ...............................
+#
+# To plot the precision-recall curve, you should use
+# :class:`~sklearn.metrics.PrecisionRecallDisplay`. There are two
+# methods available, depending on whether you have already computed the
+# predictions of the classifier or not.
+#
+# Let's first plot the precision-recall curve without the classifier
+# predictions. Thus, we should use
+# :func:`~sklearn.metrics.PrecisionRecallDisplay.from_estimator` that will
+# compute the predictions for us before plotting the curve.
 from sklearn.metrics import PrecisionRecallDisplay

-display = PrecisionRecallDisplay.from_estimator(classifier, X_test, y_test)
-display.ax_.set_title("2-class Precision-Recall curve")
+display = PrecisionRecallDisplay.from_estimator(
+    classifier, X_test, y_test, name="LinearSVC"
+)
+_ = display.ax_.set_title("2-class Precision-Recall curve")
+
+# %%
+# If we already have the estimated probabilities or scores for
+# our model, then we can use
+# :func:`~sklearn.metrics.PrecisionRecallDisplay.from_predictions`.
+y_score = classifier.decision_function(X_test)
+
+display = PrecisionRecallDisplay.from_predictions(y_test, y_score, name="LinearSVC")
+_ = display.ax_.set_title("2-class Precision-Recall curve")

 # %%
 # In multi-label settings
 # ------------------------
 #
+# The precision-recall curve does not support the multilabel setting. However,
+# one can decide how to handle this case. We show such an example below.
+#
 # Create multi-label data, fit, and predict
-# ...........................................
+# .........................................
 #
 # We create a multi-label dataset, to illustrate the precision-recall in
-# multi-label settings
+# multi-label settings.

 from sklearn.preprocessing import label_binarize

@@ -158,8 +155,9 @@
     X, Y, test_size=0.5, random_state=random_state
 )

-# We use OneVsRestClassifier for multi-label prediction
+# %%
+# We use :class:`~sklearn.multiclass.OneVsRestClassifier` for multi-label
+# prediction.
from sklearn.multiclass import OneVsRestClassifier -# Run classifier -classifier = OneVsRestClassifier(svm.LinearSVC(random_state=random_state)) +classifier = OneVsRestClassifier( + make_pipeline(StandardScaler(), LinearSVC(random_state=random_state)) +) classifier.fit(X_train, Y_train) y_score = classifier.decision_function(X_test) @@ -235,5 +267,3 @@ ax.set_title("Extension of Precision-Recall curve to multi-class") plt.show() - -# %% From d7527cdcb6518488f382987ba4c42c7b58e10412 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Jul 2021 17:30:28 +0200 Subject: [PATCH 14/25] iter --- examples/model_selection/plot_precision_recall.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/model_selection/plot_precision_recall.py b/examples/model_selection/plot_precision_recall.py index ff36271228377..85f3d69f9396e 100644 --- a/examples/model_selection/plot_precision_recall.py +++ b/examples/model_selection/plot_precision_recall.py @@ -128,7 +128,7 @@ # %% # Plot the Precision-Recall curve -# ................................ +# ............................... # # To plot the precision-recall curve, you should use # :class:`~sklearn.metrics.PrecisionRecallDisplay`. Indeed, there are two @@ -157,7 +157,7 @@ # %% # In multi-label settings -# ------------------------ +# ----------------------- # # The precision-recall curve does not support the multilabel setting. However, # one can decide how to handle this case. We show such an example below. @@ -193,7 +193,7 @@ # %% # The average precision score in multi-label settings -# .................................................... +# ................................................... from sklearn.metrics import precision_recall_curve from sklearn.metrics import average_precision_score @@ -213,7 +213,7 @@ # %% # Plot the micro-averaged Precision-Recall curve -# ............................................... +# .............................................. display = PrecisionRecallDisplay( recall=recall["micro"], precision=precision["micro"], @@ -224,7 +224,7 @@ # %% # Plot Precision-Recall curve for each class and iso-f1 curves -# ............................................................. +# ............................................................ import matplotlib.pyplot as plt from itertools import cycle From 8ab59f023c1bfdcb280780a905110628e2b4b455 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Jul 2021 17:34:59 +0200 Subject: [PATCH 15/25] DOC update user guide --- doc/modules/model_evaluation.rst | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst index 5453942cd1a13..a36c45bb06ca5 100644 --- a/doc/modules/model_evaluation.rst +++ b/doc/modules/model_evaluation.rst @@ -796,9 +796,10 @@ score: Note that the :func:`precision_recall_curve` function is restricted to the binary case. The :func:`average_precision_score` function works only in -binary classification and multilabel indicator format. The -:func:`plot_precision_recall_curve` function plots the precision recall as -follows. +binary classification and multilabel indicator format. +The :func:`PrecisionRecallDisplay.from_estimator` and +:func:`PrecisionRecallDisplay.from_predictions` functions will plot the +precision-recall curve as follows. ..
image:: ../auto_examples/model_selection/images/sphx_glr_plot_precision_recall_001.png :target: ../auto_examples/model_selection/plot_precision_recall.html#plot-the-precision-recall-curve From 7240dd8f322368fe070d68b7cccec4eb8fc63767 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Jul 2021 18:47:49 +0200 Subject: [PATCH 16/25] FIX order parameters --- .../metrics/_plot/precision_recall_curve.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index 7ac9954005cdb..f631bed27c3c2 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -160,9 +160,9 @@ def from_estimator( y, *, sample_weight=None, + pos_label=None, response_method="auto", name=None, - pos_label=None, ax=None, **kwargs, ): @@ -183,6 +183,11 @@ def from_estimator( sample_weight : array-like of shape (n_samples,), default=None Sample weights. + pos_label : str or int, default=None + The class considered as the positive class when computing the + precision and recall metrics. By default, `estimators.classes_[1]` + is considered as the positive class. + response_method : {'predict_proba', 'decision_function', 'auto'}, \ default='auto' Specifies whether to use :term:`predict_proba` or @@ -196,11 +201,6 @@ def from_estimator( ax : matplotlib axes, default=None Axes object to plot on. If `None`, a new figure and axes is created. - pos_label : str or int, default=None - The class considered as the positive class when computing the - precision and recall metrics. By default, `estimators.classes_[1]` - is considered as the positive class. - **kwargs : dict Keyword arguments to be passed to matplotlib's `plot`. @@ -261,8 +261,8 @@ def from_predictions( y_pred, *, sample_weight=None, - name=None, pos_label=None, + name=None, ax=None, **kwargs, ): @@ -279,16 +279,16 @@ def from_predictions( sample_weight : array-like of shape (n_samples,), default=None Sample weights. + pos_label : str or int, default=None + The class considered as the positive class when computing the + precision and recall metrics. + name : str, default=None Name for labeling curve. If `None`, no name is used. ax : matplotlib axes, default=None Axes object to plot on. If `None`, a new figure and axes is created. - pos_label : str or int, default=None - The class considered as the positive class when computing the - precision and recall metrics. - **kwargs : dict Keyword arguments to be passed to matplotlib's `plot`. 
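For reference alongside the parameter reordering above, here is a minimal, self-contained usage sketch of the two class methods. It mirrors the scaled LinearSVC on the first two iris classes used in the reworked example; the `name` label, the 50/50 split and the fixed random seed are illustrative choices rather than part of the patches, and matplotlib is assumed to be installed.

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.metrics import PrecisionRecallDisplay
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC

# Binary task: keep only the first two iris classes, as in the example.
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X[y < 2], y[y < 2], test_size=0.5, random_state=0
)
classifier = make_pipeline(StandardScaler(), LinearSVC(random_state=0))
classifier.fit(X_train, y_train)

# Let the display compute the decision scores itself ...
PrecisionRecallDisplay.from_estimator(classifier, X_test, y_test, name="LinearSVC")

# ... or reuse scores that were already computed.
y_score = classifier.decision_function(X_test)
PrecisionRecallDisplay.from_predictions(y_test, y_score, name="LinearSVC")
plt.show()

Either call returns a PrecisionRecallDisplay instance whose `plot` method can be invoked again on an existing axes, which is what the multi-class part of the example relies on.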
From 1a592f9f2665de047585b10ac5a0ce237eb564fe Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Jul 2021 19:13:04 +0200 Subject: [PATCH 17/25] TST add common test for future curve --- .../_plot/tests/test_common_curve_display.py | 138 ++++++++++++++++++ .../tests/test_precision_recall_display.py | 42 ------ 2 files changed, 138 insertions(+), 42 deletions(-) create mode 100644 sklearn/metrics/_plot/tests/test_common_curve_display.py diff --git a/sklearn/metrics/_plot/tests/test_common_curve_display.py b/sklearn/metrics/_plot/tests/test_common_curve_display.py new file mode 100644 index 0000000000000..158f1dafbb004 --- /dev/null +++ b/sklearn/metrics/_plot/tests/test_common_curve_display.py @@ -0,0 +1,138 @@ +import pytest + +from sklearn.base import ClassifierMixin, clone +from sklearn.compose import make_column_transformer +from sklearn.datasets import load_iris +from sklearn.exceptions import NotFittedError +from sklearn.linear_model import LogisticRegression +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.tree import DecisionTreeClassifier + +from sklearn.metrics import PrecisionRecallDisplay + + +@pytest.fixture(scope="module") +def data(): + return load_iris(return_X_y=True) + + +@pytest.fixture(scope="module") +def data_binary(data): + X, y = data + return X[y < 2], y[y < 2] + + +@pytest.mark.parametrize("Display", [PrecisionRecallDisplay]) +def test_display_curve_error_non_binary(pyplot, data, Display): + """Check that a proper error is raised when only binary classification is + supported.""" + X, y = data + clf = DecisionTreeClassifier().fit(X, y) + + msg = "DecisionTreeClassifier should be a binary classifier" + with pytest.raises(ValueError, match=msg): + Display.from_estimator(clf, X, y) + + +@pytest.mark.parametrize( + "response_method, msg", + [ + ( + "predict_proba", + "response method predict_proba is not defined in MyClassifier", + ), + ( + "decision_function", + "response method decision_function is not defined in MyClassifier", + ), + ( + "auto", + "response method decision_function or predict_proba is not " + "defined in MyClassifier", + ), + ( + "bad_method", + "response_method must be 'predict_proba', 'decision_function' or 'auto'", + ), + ], +) +@pytest.mark.parametrize("Display", [PrecisionRecallDisplay]) +def test_display_curve_error_no_response( + pyplot, + data_binary, + response_method, + msg, + Display, +): + """Check that a proper error is raised when the response method requested + is not defined for the given trained classifier.""" + X, y = data_binary + + class MyClassifier(ClassifierMixin): + def fit(self, X, y): + self.classes_ = [0, 1] + return self + + clf = MyClassifier().fit(X, y) + + with pytest.raises(ValueError, match=msg): + Display.from_estimator(clf, X, y, response_method=response_method) + + +@pytest.mark.parametrize("Display", [PrecisionRecallDisplay]) +@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) +def test_display_curve_estimator_name_multiple_calls( + pyplot, + data_binary, + Display, + constructor_name, +): + """Check that passing `name` when calling `plot` will overwrite the original name + in the legend.""" + X, y = data_binary + clf_name = "my hand-crafted name" + clf = LogisticRegression().fit(X, y) + y_pred = clf.predict_proba(X)[:, 1] + + # safe guard for the binary if/else construction + assert constructor_name in ("from_estimator", "from_predictions") + + if constructor_name == "from_estimator": + disp = 
Display.from_estimator(clf, X, y, name=clf_name) + else: + disp = Display.from_predictions(y, y_pred, name=clf_name) + assert disp.estimator_name == clf_name + pyplot.close("all") + disp.plot() + assert clf_name in disp.line_.get_label() + pyplot.close("all") + clf_name = "another_name" + disp.plot(name=clf_name) + assert clf_name in disp.line_.get_label() + + +@pytest.mark.parametrize( + "clf", + [ + LogisticRegression(), + make_pipeline(StandardScaler(), LogisticRegression()), + make_pipeline( + make_column_transformer((StandardScaler(), [0, 1])), LogisticRegression() + ), + ], +) +@pytest.mark.parametrize("Display", [PrecisionRecallDisplay]) +def test_display_curve_not_fitted_errors(pyplot, data_binary, clf, Display): + """Check that a proper error is raised when the classifier is not + fitted.""" + X, y = data_binary + # clone since we parametrize the test and the classifier will be fitted + # when testing the second and subsequent plotting function + model = clone(clf) + with pytest.raises(NotFittedError): + Display.from_estimator(model, X, y) + model.fit(X, y) + disp = Display.from_estimator(model, X, y) + assert model.__class__.__name__ in disp.line_.get_label() + assert disp.estimator_name == model.__class__.__name__ diff --git a/sklearn/metrics/_plot/tests/test_precision_recall_display.py b/sklearn/metrics/_plot/tests/test_precision_recall_display.py index db4be93acc616..2a36b0cc565ec 100644 --- a/sklearn/metrics/_plot/tests/test_precision_recall_display.py +++ b/sklearn/metrics/_plot/tests/test_precision_recall_display.py @@ -1,7 +1,6 @@ import numpy as np import pytest -from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.compose import make_column_transformer from sklearn.datasets import load_breast_cancer, make_classification from sklearn.exceptions import NotFittedError @@ -76,47 +75,6 @@ def test_plot_precision_recall_curve_deprecation(pyplot): plot_precision_recall_curve(clf, X, y) -@pytest.mark.parametrize( - "response_method, msg", - [ - ( - "predict_proba", - "response method predict_proba is not defined in MyClassifier", - ), - ( - "decision_function", - "response method decision_function is not defined in MyClassifier", - ), - ( - "auto", - "response method decision_function or predict_proba is not " - "defined in MyClassifier", - ), - ( - "bad_method", - "response_method must be 'predict_proba', 'decision_function' or 'auto'", - ), - ], -) -def test_precision_recall_display_bad_response(pyplot, response_method, msg): - """Check that the proper error is raised when passing a `response_method` - not compatible with the estimator.""" - X, y = make_classification(n_classes=2, n_samples=50, random_state=0) - - class MyClassifier(ClassifierMixin, BaseEstimator): - def fit(self, X, y): - self.fitted_ = True - self.classes_ = [0, 1] - return self - - clf = MyClassifier().fit(X, y) - - with pytest.raises(ValueError, match=msg): - PrecisionRecallDisplay.from_estimator( - clf, X, y, response_method=response_method - ) - - @pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) @pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) def test_precision_recall_display_plotting(pyplot, constructor_name, response_method): From 36602d9b92634705dd4a6b76df60c61b6c6af6d8 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Jul 2021 19:15:21 +0200 Subject: [PATCH 18/25] revert setup.cfg --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 
107a8abff8182..8ee90da7436c0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,7 +13,7 @@ addopts = --ignore maint_tools --ignore asv_benchmarks --doctest-modules - # --disable-pytest-warnings + --disable-pytest-warnings -rxXs filterwarnings = From 02251988590893b38fe0f46c8e00b1fc3fa0b28d Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Jul 2021 19:44:52 +0200 Subject: [PATCH 19/25] simplify shape --- sklearn/metrics/_plot/precision_recall_curve.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index f631bed27c3c2..66780c0042abc 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -22,10 +22,10 @@ class PrecisionRecallDisplay: Parameters ----------- - precision : ndarray of shape (n_samples,) + precision : ndarray Precision values. - recall : ndarray of shape (n_samples,) + recall : ndarray Recall values. average_precision : float, default=None From 8f4499671f687da6ece5e849cb92f4ba8e525c5a Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Jul 2021 20:43:56 +0200 Subject: [PATCH 20/25] consistency --- sklearn/metrics/_plot/precision_recall_curve.py | 5 ++++- sklearn/metrics/_plot/tests/test_precision_recall_display.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index 66780c0042abc..c8f45b10fa343 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -284,7 +284,8 @@ def from_predictions( precision and recall metrics. name : str, default=None - Name for labeling curve. If `None`, no name is used. + Name for labeling curve. If `None`, name will be set to + `"Classifier"`. ax : matplotlib axes, default=None Axes object to plot on. If `None`, a new figure and axes is created. @@ -332,6 +333,8 @@ def from_predictions( y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight ) + name = name if name is not None else "Classifier" + viz = PrecisionRecallDisplay( precision=precision, recall=recall, diff --git a/sklearn/metrics/_plot/tests/test_precision_recall_display.py b/sklearn/metrics/_plot/tests/test_precision_recall_display.py index 2a36b0cc565ec..e0a464b88c0a3 100644 --- a/sklearn/metrics/_plot/tests/test_precision_recall_display.py +++ b/sklearn/metrics/_plot/tests/test_precision_recall_display.py @@ -127,7 +127,7 @@ def test_precision_recall_display_plotting(pyplot, constructor_name, response_me "constructor_name, default_label", [ ("from_estimator", "LogisticRegression (AP = {:.2f})"), - ("from_predictions", "AP = {:.2f}"), + ("from_predictions", "Classifier (AP = {:.2f})"), ], ) def test_precision_recall_display_name(pyplot, constructor_name, default_label): From 6a67f39b22d47f81eac7ca3eef36b929b0d9632b Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 20 Jul 2021 00:24:14 +0200 Subject: [PATCH 21/25] iter --- sklearn/metrics/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_base.py b/sklearn/metrics/_base.py index 514026238aaee..5640848b1a9d4 100644 --- a/sklearn/metrics/_base.py +++ b/sklearn/metrics/_base.py @@ -246,6 +246,6 @@ def _check_pos_label_consistency(pos_label, y_true): "{-1, 1} or pass pos_label explicitly." 
) elif pos_label is None: - pos_label = 1.0 + pos_label = 1 return pos_label From cd9a82cec4093c3ce7ad31cc651c88ab0d20ae75 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 20 Jul 2021 00:49:08 +0200 Subject: [PATCH 22/25] add comment tweak --- sklearn/metrics/_plot/tests/test_precision_recall_display.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/tests/test_precision_recall_display.py b/sklearn/metrics/_plot/tests/test_precision_recall_display.py index e0a464b88c0a3..7170ffa6347c3 100644 --- a/sklearn/metrics/_plot/tests/test_precision_recall_display.py +++ b/sklearn/metrics/_plot/tests/test_precision_recall_display.py @@ -260,7 +260,9 @@ def test_plot_precision_recall_pos_label(pyplot, constructor_name, response_meth assert classifier.classes_.tolist() == ["cancer", "not cancer"] y_pred = getattr(classifier, response_method)(X_test) - y_pred_cancer = y_pred if y_pred.ndim == 1 else y_pred[:, 0] + # we select the corresponding probability columns or reverse the decision + # function otherwise + y_pred_cancer = -1 * y_pred if y_pred.ndim == 1 else y_pred[:, 0] y_pred_not_cancer = y_pred if y_pred.ndim == 1 else y_pred[:, 1] if constructor_name == "from_estimator": From e764bf28bc9b9813a941808457299663448b0abd Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 6 Aug 2021 18:03:52 +0200 Subject: [PATCH 23/25] Apply suggestions from code review Co-authored-by: Roman Yurchak --- examples/model_selection/plot_precision_recall.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/model_selection/plot_precision_recall.py b/examples/model_selection/plot_precision_recall.py index 52bf50c7069c7..c0f0a97dd44ce 100644 --- a/examples/model_selection/plot_precision_recall.py +++ b/examples/model_selection/plot_precision_recall.py @@ -135,9 +135,9 @@ # classifier or not. # # Let's first plot the precision-recall curve without the classifier -# predictions. Thus, we should use -# :func:`~sklearn.metrics.PrecisionRecallDisplay.from_estimator` that will -# compute the predictions for us before to plot the curve. +# predictions. We use +# :func:`~sklearn.metrics.PrecisionRecallDisplay.from_estimator` that +# computes the predictions for us before plotting the curve. from sklearn.metrics import PrecisionRecallDisplay display = PrecisionRecallDisplay.from_estimator( @@ -146,7 +146,7 @@ _ = display.ax_.set_title("2-class Precision-Recall curve") # %% -# In the case, that we already got the estimated probabilities or scores for +# If we already got the estimated probabilities or scores for # our model, then we can use # :func:`~sklearn.metrics.PrecisionRecallDisplay.from_predictions`. y_score = classifier.decision_function(X_test) From 1f38f5c9dfd2aaab038e2971cef6e4a385595d74 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 6 Aug 2021 23:18:11 +0200 Subject: [PATCH 24/25] fix doc --- doc/modules/model_evaluation.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst index a36c45bb06ca5..7d6010f57c568 100644 --- a/doc/modules/model_evaluation.rst +++ b/doc/modules/model_evaluation.rst @@ -437,7 +437,7 @@ In the multilabel case with binary label indicators:: ..
topic:: Example: - * See :ref:`sphx_glr_auto_examples_feature_selection_plot_permutation_test_for_classification.py` + * See :ref:`sphx_glr_auto_examples_model_selection_plot_permutation_test_for_classification.py` for an example of accuracy score usage using permutations of the dataset. From 7d013ffe909d43f882dd0a70656dfde6b6fc2dc2 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 9 Aug 2021 18:41:50 +0200 Subject: [PATCH 25/25] update error message --- sklearn/metrics/_plot/tests/test_precision_recall_display.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_plot/tests/test_precision_recall_display.py b/sklearn/metrics/_plot/tests/test_precision_recall_display.py index 7170ffa6347c3..165e2b75df36e 100644 --- a/sklearn/metrics/_plot/tests/test_precision_recall_display.py +++ b/sklearn/metrics/_plot/tests/test_precision_recall_display.py @@ -21,7 +21,7 @@ ) -def test_confusion_matrix_display_validation(pyplot): +def test_precision_recall_display_validation(pyplot): """Check that we raise the proper error when validating parameters.""" X, y = make_classification( n_samples=100, n_informative=5, n_classes=5, random_state=0 @@ -39,7 +39,7 @@ def test_confusion_matrix_display_validation(pyplot): with pytest.raises(ValueError, match=err_msg): PrecisionRecallDisplay.from_estimator(regressor, X, y) - err_msg = "SVC should be a binary classifier" + err_msg = "Expected 'estimator' to be a binary classifier, but got SVC" with pytest.raises(ValueError, match=err_msg): PrecisionRecallDisplay.from_estimator(classifier, X, y)
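To close the series, a small sketch of the validation behaviour exercised by the renamed test above: the five-class dataset and the multiclass SVC follow the test itself, `pytest.raises(..., match=...)` performs a regex search against the reworded error message, and matplotlib is assumed to be installed since `from_estimator` checks for it first.

import pytest
from sklearn.datasets import make_classification
from sklearn.metrics import PrecisionRecallDisplay
from sklearn.svm import SVC

# Five-class problem, as in the updated test.
X, y = make_classification(
    n_samples=100, n_informative=5, n_classes=5, random_state=0
)
classifier = SVC().fit(X, y)

# A non-binary classifier is rejected with the updated error wording.
with pytest.raises(
    ValueError, match="Expected 'estimator' to be a binary classifier, but got SVC"
):
    PrecisionRecallDisplay.from_estimator(classifier, X, y)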