diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index 4b8202d60d53e..4243bccd23c9a 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -21,7 +21,7 @@ import warnings from functools import partial -from numbers import Real +from numbers import Real, Integral import numpy as np from scipy.sparse import csr_matrix, issparse @@ -34,7 +34,7 @@ from ..utils.multiclass import type_of_target from ..utils.extmath import stable_cumsum from ..utils.sparsefuncs import count_nonzero -from ..utils._param_validation import validate_params, StrOptions +from ..utils._param_validation import validate_params, StrOptions, Interval from ..exceptions import UndefinedMetricWarning from ..preprocessing import label_binarize from ..utils._encode import _encode, _unique @@ -1445,6 +1445,16 @@ def _check_dcg_target_type(y_true): ) +@validate_params( + { + "y_true": ["array-like"], + "y_score": ["array-like"], + "k": [Interval(Integral, 1, None, closed="left"), None], + "log_base": [Interval(Real, 0.0, None, closed="neither")], + "sample_weight": ["array-like", None], + "ignore_ties": ["boolean"], + } +) def dcg_score( y_true, y_score, *, k=None, log_base=2, sample_weight=None, ignore_ties=False ): @@ -1461,11 +1471,11 @@ def dcg_score( Parameters ---------- - y_true : ndarray of shape (n_samples, n_labels) + y_true : array-like of shape (n_samples, n_labels) True targets of multilabel classification, or true scores of entities to be ranked. - y_score : ndarray of shape (n_samples, n_labels) + y_score : array-like of shape (n_samples, n_labels) Target scores, can either be probability estimates, confidence values, or non-thresholded measure of decisions (as returned by "decision_function" on some classifiers). @@ -1478,7 +1488,7 @@ def dcg_score( Base of the logarithm used for the discount. A low value means a sharper discount (top results are more important). - sample_weight : ndarray of shape (n_samples,), default=None + sample_weight : array-like of shape (n_samples,), default=None Sample weights. If `None`, all samples are given the same weight. ignore_ties : bool, default=False diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 836139b85b341..cbe75a57a3705 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -124,6 +124,7 @@ def _check_function_param_validation( "sklearn.metrics.confusion_matrix", "sklearn.metrics.coverage_error", "sklearn.metrics.d2_pinball_score", + "sklearn.metrics.dcg_score", "sklearn.metrics.det_curve", "sklearn.metrics.f1_score", "sklearn.metrics.hamming_loss",