From f772ddada0b788fd6bc54bb6ea44b512838fecee Mon Sep 17 00:00:00 2001 From: zeeshan Date: Mon, 27 Feb 2023 01:46:48 +0530 Subject: [PATCH 1/2] Added parameter validation for metrics.precision_score, added test for metrics.precision_score in tes_public_functions.py --- sklearn/metrics/_classification.py | 17 +++++++++++++++++ sklearn/tests/test_public_functions.py | 1 + 2 files changed, 18 insertions(+) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index f7d002ac23743..ac488ac3af772 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -1904,6 +1904,23 @@ class after being classified as negative. This is the case when the return positive_likelihood_ratio, negative_likelihood_ratio +@validate_params( + { + "y_true": ["array-like", "sparse matrix"], + "y_pred": ["array-like", "sparse matrix"], + "labels": ["array-like", None], + "pos_label": [Real, str, "boolean", None], + "average": [ + StrOptions({"micro", "macro", "samples", "weighted", "binary"}), + None, + ], + "sample_weight": ["array-like", None], + "zero_division": [ + Options(Real, {0, 1}), + StrOptions({"warn"}), + ], + } +) def precision_score( y_true, y_pred, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 4e13bb46ef645..f52a9755dc016 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -141,6 +141,7 @@ def _check_function_param_validation( "sklearn.model_selection.train_test_split", "sklearn.random_projection.johnson_lindenstrauss_min_dim", "sklearn.svm.l1_min_c", + "sklearn.metrics.precision_score", ] From 7442e9c2d43fd88891d01c0dad4fa560e92e8187 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 27 Feb 2023 11:27:38 +0100 Subject: [PATCH 2/2] keep alphabetical order --- sklearn/tests/test_public_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index f52a9755dc016..665fc8c7af98b 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -135,13 +135,13 @@ def _check_function_param_validation( "sklearn.metrics.mutual_info_score", "sklearn.metrics.pairwise.additive_chi2_kernel", "sklearn.metrics.precision_recall_fscore_support", + "sklearn.metrics.precision_score", "sklearn.metrics.r2_score", "sklearn.metrics.roc_curve", "sklearn.metrics.zero_one_loss", "sklearn.model_selection.train_test_split", "sklearn.random_projection.johnson_lindenstrauss_min_dim", "sklearn.svm.l1_min_c", - "sklearn.metrics.precision_score", ]