8000 MAINT Parameters validation for metrics.ndcg_score (#25885) · jeremiedbb/scikit-learn@5a17650 · GitHub 8000
[go: up one dir, main page]

Skip to content

Commit 5a17650

Browse files
MAINT Parameters validation for metrics.ndcg_score (scikit-learn#25885)
1 parent f151833 commit 5a17650

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

sklearn/metrics/_ranking.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1620,6 +1620,15 @@ def _ndcg_sample_scores(y_true, y_score, k=None, ignore_ties=False):
16201620
return gain
16211621

16221622

1623+
@validate_params(
1624+
{
1625+
"y_true": ["array-like"],
1626+
"y_score": ["array-like"],
1627+
"k": [Interval(Integral, 1, None, closed="left"), None],
1628+
"sample_weight": ["array-like", None],
1629+
"ignore_ties": ["boolean"],
1630+
}
1631+
)
16231632
def ndcg_score(y_true, y_score, *, k=None, sample_weight=None, ignore_ties=False):
16241633
"""Compute Normalized Discounted Cumulative Gain.
16251634
@@ -1633,15 +1642,15 @@ def ndcg_score(y_true, y_score, *, k=None, sample_weight=None, ignore_ties=False
16331642
16341643
Parameters
16351644
----------
1636-
y_true : ndarray of shape (n_samples, n_labels)
1645+
y_true : array-like of shape (n_samples, n_labels)
16371646
True targets of multilabel classification, or true scores of entities
16381647
to be ranked. Negative values in `y_true` may result in an output
16391648
that is not between 0 and 1.
16401649
16411650
.. versionchanged:: 1.2
16421651
These negative values are deprecated, and will raise an error in v1.4.
16431652
1644-
y_score : ndarray of shape (n_samples, n_labels)
1653+
y_score : array-like of shape (n_samples, n_labels)
16451654
Target scores, can either be probability estimates, confidence values,
16461655
or non-thresholded measure of decisions (as returned by
16471656
"decision_function" on some classifiers).
@@ -1650,7 +1659,7 @@ def ndcg_score(y_true, y_score, *, k=None, sample_weight=None, ignore_ties=False
16501659
Only consider the highest k scores in the ranking. If `None`, use all
16511660
outputs.
16521661
1653-
sample_weight : ndarray of shape (n_samples,), default=None
1662+
sample_weight : array-like of shape (n_samples,), default=None
16541663
Sample weights. If `None`, all samples are given the same weight.
16551664
16561665
ignore_ties : bool, default=False

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def _check_function_param_validation(
180180
"sklearn.metrics.median_absolute_error",
181181
"sklearn.metrics.multilabel_confusion_matrix",
182182
"sklearn.metrics.mutual_info_score",
183+
"sklearn.metrics.ndcg_score",
183184
"sklearn.metrics.pairwise.additive_chi2_kernel",
184185
"sklearn.metrics.precision_recall_curve",
185186
"sklearn.metrics.precision_recall_fscore_support",

0 commit comments

Comments
 (0)
0