diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index c351320a48278..b19fc2e7f3f70 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -1238,6 +1238,17 @@ def mean_gamma_deviance(y_true, y_pred, *, sample_weight=None): return mean_tweedie_deviance(y_true, y_pred, sample_weight=sample_weight, power=2) +@validate_params( + { + "y_true": ["array-like"], + "y_pred": ["array-like"], + "sample_weight": ["array-like", None], + "power": [ + Interval(Real, None, 0, closed="right"), + Interval(Real, 1, None, closed="left"), + ], + } +) def d2_tweedie_score(y_true, y_pred, *, sample_weight=None, power=0): """D^2 regression score function, fraction of Tweedie deviance explained. @@ -1257,7 +1268,7 @@ def d2_tweedie_score(y_true, y_pred, *, sample_weight=None, power=0): y_pred : array-like of shape (n_samples,) Estimated target values. - sample_weight : array-like of shape (n_samples,), optional + sample_weight : array-like of shape (n_samples,), default=None Sample weights. power : float, default=0 diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 4b1934c378fbf..de693dc4971e2 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -161,6 +161,7 @@ def _check_function_param_validation( "sklearn.metrics.confusion_matrix", "sklearn.metrics.coverage_error", "sklearn.metrics.d2_pinball_score", + "sklearn.metrics.d2_tweedie_score", "sklearn.metrics.dcg_score", "sklearn.metrics.det_curve", "sklearn.metrics.f1_score",