8000 MAINT Parameters validation for sklearn.metrics.adjusted_rand_score (… · scikit-learn/scikit-learn@3023f19 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3023f19

Browse files
MAINT Parameters validation for sklearn.metrics.adjusted_rand_score (#26134)
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent c8fb561 commit 3023f19

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

sklearn/metrics/cluster/_supervised.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,12 @@ def rand_score(labels_true, labels_pred):
325325
return numerator / denominator
326326

327327

328+
@validate_params(
329+
{
330+
"labels_true": ["array-like"],
331+
"labels_pred": ["array-like"],
332+
}
333+
)
328334
def adjusted_rand_score(labels_true, labels_pred):
329335
"""Rand index adjusted for chance.
330336
@@ -352,10 +358,10 @@ def adjusted_rand_score(labels_true, labels_pred):
352358
353359
Parameters
354360
----------
355-
labels_true : int array, shape = [n_samples]
361+
labels_true : array-like of shape (n_samples,), dtype=int
356362
Ground truth class labels to be used as a reference.
357363
358-
labels_pred : array-like of shape (n_samples,)
364+
labels_pred : array-like of shape (n_samples,), dtype=int
359365
Cluster labels to evaluate.
360366
361367
Returns

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def _check_function_param_validation(
209209
"sklearn.metrics.mutual_info_score",
210210
"sklearn.metrics.ndcg_score",
211211
"sklearn.metrics.pair_confusion_matrix",
212+
"sklearn.metrics.adjusted_rand_score",
212213
"sklearn.metrics.pairwise.additive_chi2_kernel",
213214
"sklearn.metrics.pairwise.cosine_distances",
214215
"sklearn.metrics.pairwise.cosine_similarity",

0 commit comments

Comments
 (0)
0