From 31c92557944956a6493b2385f929838e113a9e20 Mon Sep 17 00:00:00 2001 From: "wishyut.pitawanik" Date: Sat, 25 Feb 2023 00:24:25 +0100 Subject: [PATCH 1/2] added parameter validation testing metrics.precision_recall_curve in test_public_functions --- sklearn/metrics/_ranking.py | 8 ++++++++ sklearn/tests/test_public_functions.py | 1 + 2 files changed, 9 insertions(+) diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index 297e83173e47e..be55bb9c999b0 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -814,6 +814,14 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None): return fps, tps, y_score[threshold_idxs] +@validate_params( + { + "y_true": ["array-like"], + "probas_pred": ["array-like"], + "pos_label": [Real, str, "boolean", None], + "sample_weight": ["array-like", None], + } +) def precision_recall_curve(y_true, probas_pred, *, pos_label=None, sample_weight=None): """Compute precision-recall pairs for different probability thresholds. diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 4e13bb46ef645..a9cb675f43423 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -134,6 +134,7 @@ def _check_function_param_validation( "sklearn.metrics.multilabel_confusion_matrix", "sklearn.metrics.mutual_info_score", "sklearn.metrics.pairwise.additive_chi2_kernel", + "sklearn.metrics.precision_recall_curve", "sklearn.metrics.precision_recall_fscore_support", "sklearn.metrics.r2_score", "sklearn.metrics.roc_curve", From c0a1c20dcc46f71d976f74b93536f549f496e19d Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 27 Feb 2023 11:23:09 +0100 Subject: [PATCH 2/2] fix docstring --- sklearn/metrics/_ranking.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index be55bb9c999b0..e3d46f5138fb2 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -847,11 +847,11 @@ def precision_recall_curve(y_true, probas_pred, *, pos_label=None, sample_weight Parameters ---------- - y_true : ndarray of shape (n_samples,) + y_true : array-like of shape (n_samples,) True binary labels. If labels are not either {-1, 1} or {0, 1}, then pos_label should be explicitly given. - probas_pred : ndarray of shape (n_samples,) + probas_pred : array-like of shape (n_samples,) Target scores, can either be probability estimates of the positive class, or non-thresholded measure of decisions (as returned by `decision_function` on some classifiers).