From 70bb286a70e303cd71dfc015889e1edf38d13832 Mon Sep 17 00:00:00 2001 From: zeeshan Date: Sat, 18 Mar 2023 13:21:18 +0530 Subject: [PATCH 1/2] Added Parameter Validation for cluster.adjusted_mutual_info_score --- sklearn/metrics/cluster/_supervised.py | 9 ++++++++- sklearn/tests/test_public_functions.py | 1 + 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/cluster/_supervised.py b/sklearn/metrics/cluster/_supervised.py index 1cd7f6c5ce80e..f55f953772693 100644 --- a/sklearn/metrics/cluster/_supervised.py +++ b/sklearn/metrics/cluster/_supervised.py @@ -27,7 +27,7 @@ from ...utils.multiclass import type_of_target from ...utils.validation import check_array, check_consistent_length from ...utils._param_validation import validate_params -from ...utils._param_validation import Interval +from ...utils._param_validation import Interval, StrOptions def check_clusterings(labels_true, labels_pred): @@ -847,6 +847,13 @@ def mutual_info_score(labels_true, labels_pred, *, contingency=None): return np.clip(mi.sum(), 0.0, None) +@validate_params( + { + "labels_true": ["array-like"], + "labels_pred": ["array-like"], + "average_method": [StrOptions({"arithmetic", "max", "min", "geometric"})], + } +) def adjusted_mutual_info_score( labels_true, labels_pred, *, average_method="arithmetic" ): diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index f72d622e53902..44480dd869d68 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -152,6 +152,7 @@ def _check_function_param_validation( "sklearn.metrics.brier_score_loss", "sklearn.metrics.class_likelihood_ratios", "sklearn.metrics.classification_report", + "sklearn.metrics.cluster.adjusted_mutual_info_score", "sklearn.metrics.cluster.contingency_matrix", "sklearn.metrics.cohen_kappa_score", "sklearn.metrics.confusion_matrix", From b1936572c44bd93a26e2b1229b1a31f40af9cbfb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= <34657725+jeremiedbb@users.noreply.github.com> Date: Tue, 21 Mar 2023 12:49:30 +0100 Subject: [PATCH 2/2] cln docstring --- sklearn/metrics/cluster/_supervised.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/cluster/_supervised.py b/sklearn/metrics/cluster/_supervised.py index f55f953772693..d05ccec33beb6 100644 --- a/sklearn/metrics/cluster/_supervised.py +++ b/sklearn/metrics/cluster/_supervised.py @@ -883,7 +883,7 @@ def adjusted_mutual_info_score( Parameters ---------- - labels_true : int array, shape = [n_samples] + labels_true : int array-like of shape (n_samples,) A clustering of the data into disjoint subsets, called :math:`U` in the above formula. @@ -891,9 +891,8 @@ def adjusted_mutual_info_score( A clustering of the data into disjoint subsets, called :math:`V` in the above formula. - average_method : str, default='arithmetic' - How to compute the normalizer in the denominator. Possible options - are 'min', 'geometric', 'arithmetic', and 'max'. + average_method : {'min', 'geometric', 'arithmetic', 'max'}, default='arithmetic' + How to compute the normalizer in the denominator. .. versionadded:: 0.20