8000 MAINT Parameters validation for metrics.dcg_score (#25749) · scikit-learn/scikit-learn@4180b07 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4180b07

Browse files
MAINT Parameters validation for metrics.dcg_score (#25749)
1 parent 73c17de commit 4180b07

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

sklearn/metrics/_ranking.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import warnings
2323
from functools import partial
24-
from numbers import Real
24+
from numbers import Real, Integral
2525

2626
import numpy as np
2727
from scipy.sparse import csr_matrix, issparse
@@ -34,7 +34,7 @@
3434
from ..utils.multiclass import type_of_target
3535
from ..utils.extmath import stable_cumsum
3636
from ..utils.sparsefuncs import count_nonzero
37-
from ..utils._param_validation import validate_params, StrOptions
37+
from ..utils._param_validation import validate_params, StrOptions, Interval
3838
from ..exceptions import UndefinedMetricWarning
3939
from ..preprocessing import label_binarize
4040
from ..utils._encode import _encode, _unique
@@ -1445,6 +1445,16 @@ def _check_dcg_target_type(y_true):
14451445
)
14461446

14471447

1448+
@validate_params(
1449+
{
1450+
"y_true": ["array-like"],
1451+
"y_score": ["array-like"],
1452+
"k": [Interval(Integral, 1, None, closed="left"), None],
1453+
"log_base": [Interval(Real, 0.0, None, closed="neither")],
1454+
"sample_weight": ["array-like", None],
1455+
"ignore_ties": ["boolean"],
1456+
}
1457+
)
14481458
def dcg_score(
14491459
y_true, y_score, *, k=None, log_base=2, sample_weight=None, ignore_ties=False
14501460
):
@@ -1461,11 +1471,11 @@ def dcg_score(
14611471
14621472
Parameters
14631473
----------
1464-
y_true : ndarray of shape (n_samples, n_labels)
1474+
y_true : array-like of shape (n_samples, n_labels)
14651475
True targets of multilabel classification, or true scores of entities
14661476
to be ranked.
14671477
1468-
y_score : ndarray of shape (n_samples, n_labels)
1478+
y_score : array-like of shape (n_samples, n_labels)
14691479
Target scores, can either be probability estimates, confidence values,
14701480
or non-thresholded measure of decisions (as returned by
14711481
"decision_function" on some classifiers).
@@ -1478,7 +1488,7 @@ def dcg_score(
14781488
Base of the logarithm used for the discount. A low value means a
14791489
sharper discount (top results are more important).
14801490
1481-
sample_weight : ndarray of shape (n_samples,), default=None
1491+
sample_weight : array-like of shape (n_samples,), default=None
14821492
Sample weights. If `None`, all samples are given the same weight.
14831493
14841494
ignore_ties : bool, default=False

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def _check_function_param_validation(
124124
"sklearn.metrics.confusion_matrix",
125125
"sklearn.metrics.coverage_error",
126126
"sklearn.metrics.d2_pinball_score",
127+
"sklearn.metrics.dcg_score",
127128
"sklearn.metrics.det_curve",
128129
"sklearn.metrics.f1_score",
129130
"sklearn.metrics.hamming_loss",

0 commit comments

Comments
 (0)
0