FEA add d2_tweedie_score (#17036) · baam25simo/scikit-learn@9061ff9

Commit 9061ff9

FEA add d2_tweedie_score (scikit-learn#17036)

1 parent 767d0a4 commit 9061ff9
File tree: 7 files changed, +195 −18 lines

doc/modules/classes.rst
Lines changed: 1 addition & 0 deletions

@@ -994,6 +994,7 @@ details.
    metrics.mean_poisson_deviance
    metrics.mean_gamma_deviance
    metrics.mean_tweedie_deviance
+   metrics.d2_tweedie_score
    metrics.mean_pinball_loss
 
 Multilabel ranking metrics

doc/modules/model_evaluation.rst
Lines changed: 29 additions & 1 deletion

@@ -2355,6 +2355,34 @@ the difference in errors decreases. Finally, by setting, ``power=2``::
 we would get identical errors. The deviance when ``power=2`` is thus only
 sensitive to relative errors.
 
+.. _d2_tweedie_score:
+
+D² score, the coefficient of determination
+-------------------------------------------
+
+The :func:`d2_tweedie_score` function computes the percentage of deviance
+explained. It is a generalization of R², where the squared error is replaced by
+the Tweedie deviance. D², also known as McFadden's likelihood ratio index, is
+calculated as
+
+.. math::
+
+  D^2(y, \hat{y}) = 1 - \frac{\text{D}(y, \hat{y})}{\text{D}(y, \bar{y})} \,.
+
+The argument ``power`` defines the Tweedie power as for
+:func:`mean_tweedie_deviance`. Note that for `power=0`,
+:func:`d2_tweedie_score` equals :func:`r2_score` (for single targets).
+
+Like R², the best possible score is 1.0 and it can be negative (because the
+model can be arbitrarily worse). A constant model that always predicts the
+expected value of y, disregarding the input features, would get a D² score
+of 0.0.
+
+A scorer object with a specific choice of ``power`` can be built by::
+
+    >>> from sklearn.metrics import d2_tweedie_score, make_scorer
+    >>> d2_tweedie_score_15 = make_scorer(d2_tweedie_score, power=1.5)
+
 .. _pinball_loss:
 
 Pinball loss
@@ -2387,7 +2415,7 @@ Here is a small example of usage of the :func:`mean_pinball_loss` function::
     >>> mean_pinball_loss(y_true, y_true, alpha=0.9)
     0.0
 
-It is possible to build a scorer object with a specific choice of alpha::
+It is possible to build a scorer object with a specific choice of ``alpha``::
 
     >>> from sklearn.metrics import make_scorer
     >>> mean_pinball_loss_95p = make_scorer(mean_pinball_loss, alpha=0.95)
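The new metric and scorer can be exercised end to end. A minimal sketch in Python, reusing the example data from the docstring added in sklearn/metrics/_regression.py below (the expected value is the one documented there):

    from sklearn.metrics import d2_tweedie_score, make_scorer

    y_true = [0.5, 1, 2.5, 7]
    y_pred = [1, 1, 5, 3.5]

    # power=0 reduces to r2_score; power=1 scores against the Poisson deviance.
    print(d2_tweedie_score(y_true, y_pred, power=1))  # ~0.487 per the docstring

    # A scorer with a fixed Tweedie power, e.g. for use as `scoring` in
    # cross_val_score or GridSearchCV.
    d2_tweedie_score_15 = make_scorer(d2_tweedie_score, power=1.5)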

doc/whats_new/v1.0.rst
Lines changed: 7 additions & 1 deletion

@@ -601,6 +601,12 @@ Changelog
   quantile regression. :pr:`19415` by :user:`Xavier Dupré <sdpython>`
   and :user:`Olivier Grisel <ogrisel>`.
 
+- |Feature| :func:`metrics.d2_tweedie_score` calculates the D^2 regression
+  score for Tweedie deviances with power parameter ``power``. This is a
+  generalization of the `r2_score` and can be interpreted as percentage of
+  Tweedie deviance explained.
+  :pr:`17036` by :user:`Christian Lorentzen <lorentzenchr>`.
+
 - |Feature| :func:`metrics.mean_squared_log_error` now supports
   `squared=False`.
   :pr:`20326` by :user:`Uttam kumar <helper-uttam>`.
@@ -718,7 +724,7 @@ Changelog
 .............................
 
 - |Fix| :class:`neural_network.MLPClassifier` and
-  :class:`neural_network.MLPRegressor` now correct supports continued training
+  :class:`neural_network.MLPRegressor` now correctly support continued training
   when loading from a pickled file. :pr:`19631` by `Thomas Fan`_.
 
 :mod:`sklearn.pipeline`

sklearn/metrics/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -74,6 +74,7 @@
 from ._regression import mean_tweedie_deviance
 from ._regression import mean_poisson_deviance
 from ._regression import mean_gamma_deviance
+from ._regression import d2_tweedie_score
 
 
 from ._scorer import check_scoring
@@ -109,6 +110,7 @@
     "confusion_matrix",
     "consensus_score",
     "coverage_error",
+    "d2_tweedie_score",
     "dcg_score",
     "davies_bouldin_score",
     "DetCurveDisplay",

sklearn/metrics/_regression.py
Lines changed: 107 additions & 2 deletions

@@ -24,15 +24,16 @@
 # Uttam kumar <bajiraouttamsinha@gmail.com>
 # License: BSD 3 clause
 
-import numpy as np
 import warnings
 
+import numpy as np
+
 from .._loss.glm_distribution import TweedieDistribution
+from ..exceptions import UndefinedMetricWarning
 from ..utils.validation import check_array, check_consistent_length, _num_samples
 from ..utils.validation import column_or_1d
 from ..utils.validation import _check_sample_weight
 from ..utils.stats import _weighted_percentile
-from ..exceptions import UndefinedMetricWarning
 
 
 __ALL__ = [
@@ -986,3 +987,107 @@ def mean_gamma_deviance(y_true, y_pred, *, sample_weight=None):
     1.0568...
     """
     return mean_tweedie_deviance(y_true, y_pred, sample_weight=sample_weight, power=2)
+
+
+def d2_tweedie_score(y_true, y_pred, *, sample_weight=None, power=0):
+    """D^2 regression score function, percentage of Tweedie deviance explained.
+
+    Best possible score is 1.0 and it can be negative (because the model can be
+    arbitrarily worse). A model that always uses the empirical mean of `y_true` as
+    constant prediction, disregarding the input features, gets a D^2 score of 0.0.
+
+    Read more in the :ref:`User Guide <d2_tweedie_score>`.
+
+    .. versionadded:: 1.0
+
+    Parameters
+    ----------
+    y_true : array-like of shape (n_samples,)
+        Ground truth (correct) target values.
+
+    y_pred : array-like of shape (n_samples,)
+        Estimated target values.
+
+    sample_weight : array-like of shape (n_samples,), optional
+        Sample weights.
+
+    power : float, default=0
+        Tweedie power parameter. Either power <= 0 or power >= 1.
+
+        The higher `p` the less weight is given to extreme
+        deviations between true and predicted targets.
+
+        - power < 0: Extreme stable distribution. Requires: y_pred > 0.
+        - power = 0 : Normal distribution, output corresponds to r2_score.
+          y_true and y_pred can be any real numbers.
+        - power = 1 : Poisson distribution. Requires: y_true >= 0 and
+          y_pred > 0.
+        - 1 < p < 2 : Compound Poisson distribution. Requires: y_true >= 0
+          and y_pred > 0.
+        - power = 2 : Gamma distribution. Requires: y_true > 0 and y_pred > 0.
+        - power = 3 : Inverse Gaussian distribution. Requires: y_true > 0
+          and y_pred > 0.
+        - otherwise : Positive stable distribution. Requires: y_true > 0
+          and y_pred > 0.
+
+    Returns
+    -------
+    z : float or ndarray of floats
+        The D^2 score.
+
+    Notes
+    -----
+    This is not a symmetric function.
+
+    Like R^2, D^2 score may be negative (it need not actually be the square of
+    a quantity D).
+
+    This metric is not well-defined for single samples and will return a NaN
+    value if n_samples is less than two.
+
+    References
+    ----------
+    .. [1] Eq. (3.11) of Hastie, Trevor J., Robert Tibshirani and Martin J.
+           Wainwright. "Statistical Learning with Sparsity: The Lasso and
+           Generalizations." (2015). https://trevorhastie.github.io
+
+    Examples
+    --------
+    >>> from sklearn.metrics import d2_tweedie_score
+    >>> y_true = [0.5, 1, 2.5, 7]
+    >>> y_pred = [1, 1, 5, 3.5]
+    >>> d2_tweedie_score(y_true, y_pred)
+    0.285...
+    >>> d2_tweedie_score(y_true, y_pred, power=1)
+    0.487...
+    >>> d2_tweedie_score(y_true, y_pred, power=2)
+    0.630...
+    >>> d2_tweedie_score(y_true, y_true, power=2)
+    1.0
+    """
+    y_type, y_true, y_pred, _ = _check_reg_targets(
+        y_true, y_pred, None, dtype=[np.float64, np.float32]
+    )
+    if y_type == "continuous-multioutput":
+        raise ValueError("Multioutput not supported in d2_tweedie_score")
+    check_consistent_length(y_true, y_pred, sample_weight)
+
+    if _num_samples(y_pred) < 2:
+        msg = "D^2 score is not well-defined with less than two samples."
+        warnings.warn(msg, UndefinedMetricWarning)
+        return float("nan")
+
+    if sample_weight is not None:
+        sample_weight = column_or_1d(sample_weight)
+        sample_weight = sample_weight[:, np.newaxis]
+
+    dist = TweedieDistribution(power=power)
+
+    dev = dist.unit_deviance(y_true, y_pred, check_input=True)
+    numerator = np.average(dev, weights=sample_weight)
+
+    y_avg = np.average(y_true, weights=sample_weight)
+    dev = dist.unit_deviance(y_true, y_avg, check_input=True)
+    denominator = np.average(dev, weights=sample_weight)
+
+    return 1 - numerator / denominator
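Because mean_tweedie_deviance averages the same unit deviances that form the numerator and denominator above, the score can be reproduced from the public API alone. A minimal sketch for the unweighted case, using the docstring's example data:

    import numpy as np
    from sklearn.metrics import d2_tweedie_score, mean_tweedie_deviance

    y_true = np.array([0.5, 1, 2.5, 7])
    y_pred = np.array([1, 1, 5, 3.5])
    power = 1.5

    # Numerator: mean Tweedie deviance of the model's predictions.
    num = mean_tweedie_deviance(y_true, y_pred, power=power)

    # Denominator: mean deviance of the constant baseline that always
    # predicts the empirical mean of y_true.
    baseline = np.full_like(y_true, y_true.mean())
    den = mean_tweedie_deviance(y_true, baseline, power=power)

    assert np.isclose(d2_tweedie_score(y_true, y_pred, power=power), 1 - num / den)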

sklearn/metrics/tests/test_common.py
Lines changed: 4 additions & 0 deletions

@@ -29,6 +29,7 @@
 from sklearn.metrics import cohen_kappa_score
 from sklearn.metrics import confusion_matrix
 from sklearn.metrics import coverage_error
+from sklearn.metrics import d2_tweedie_score
 from sklearn.metrics import det_curve
 from sklearn.metrics import explained_variance_score
 from sklearn.metrics import f1_score
@@ -110,6 +111,7 @@
     "mean_poisson_deviance": mean_poisson_deviance,
     "mean_gamma_deviance": mean_gamma_deviance,
     "mean_compound_poisson_deviance": partial(mean_tweedie_deviance, power=1.4),
+    "d2_tweedie_score": partial(d2_tweedie_score, power=1.4),
 }
 
 CLASSIFICATION_METRICS = {
@@ -510,6 +512,7 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
     "mean_gamma_deviance",
     "mean_poisson_deviance",
     "mean_compound_poisson_deviance",
+    "d2_tweedie_score",
     "mean_absolute_percentage_error",
 }
 
@@ -526,6 +529,7 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
     "mean_poisson_deviance",
     "mean_gamma_deviance",
     "mean_compound_poisson_deviance",
+    "d2_tweedie_score",
 }
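The common test-suite calls each REGRESSION_METRICS entry with the standard (y_true, y_pred) signature, so the extra power keyword has to be bound up front, which is what the partial above does. The same pattern outside the test-suite (power=1.4 mirrors the test entry; the data are arbitrary positive values):

    from functools import partial
    from sklearn.metrics import d2_tweedie_score

    # Fix the Tweedie power so the metric can be called as metric(y_true, y_pred).
    d2_tweedie_score_14 = partial(d2_tweedie_score, power=1.4)
    score = d2_tweedie_score_14([0.5, 1, 2.5, 7], [1, 1, 5, 3.5])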
sklearn/metrics/tests/test_regression.py
Lines changed: 45 additions & 14 deletions

@@ -1,6 +1,7 @@
 import numpy as np
 from scipy import optimize
 from numpy.testing import assert_allclose
+from scipy.special import factorial, xlogy
 from itertools import product
 import pytest
 
@@ -20,6 +21,7 @@
 from sklearn.metrics import mean_pinball_loss
 from sklearn.metrics import r2_score
 from sklearn.metrics import mean_tweedie_deviance
+from sklearn.metrics import d2_tweedie_score
 from sklearn.metrics import make_scorer
 
 from sklearn.metrics._regression import _check_reg_targets
@@ -53,6 +55,9 @@ def test_regression_metrics(n_samples=50):
         mean_tweedie_deviance(y_true, y_pred, power=0),
         mean_squared_error(y_true, y_pred),
     )
+    assert_almost_equal(
+        d2_tweedie_score(y_true, y_pred, power=0), r2_score(y_true, y_pred)
+    )
 
     # Tweedie deviance needs positive y_pred, except for p=0,
     # p>=2 needs positive y_true
@@ -78,6 +83,17 @@ def test_regression_metrics(n_samples=50):
         mean_tweedie_deviance(y_true, y_pred, power=3), np.sum(1 / y_true) / (4 * n)
     )
 
+    dev_mean = 2 * np.mean(xlogy(y_true, 2 * y_true / (n + 1)))
+    assert_almost_equal(
+        d2_tweedie_score(y_true, y_pred, power=1),
+        1 - (n + 1) * (1 - np.log(2)) / dev_mean,
+    )
+
+    dev_mean = 2 * np.log((n + 1) / 2) - 2 / n * np.log(factorial(n))
+    assert_almost_equal(
+        d2_tweedie_score(y_true, y_pred, power=2), 1 - (2 * np.log(2) - 1) / dev_mean
+    )
+
 
 def test_mean_squared_error_multioutput_raw_value_squared():
     # non-regression test for
@@ -131,59 +147,74 @@ def test_regression_metrics_at_limits():
     assert_almost_equal(max_error([0.0], [0.0]), 0.0)
     assert_almost_equal(explained_variance_score([0.0], [0.0]), 1.0)
     assert_almost_equal(r2_score([0.0, 1], [0.0, 1]), 1.0)
-    err_msg = (
+    msg = (
         "Mean Squared Logarithmic Error cannot be used when targets "
         "contain negative values."
     )
-    with pytest.raises(ValueError, match=err_msg):
+    with pytest.raises(ValueError, match=msg):
         mean_squared_log_error([-1.0], [-1.0])
-    err_msg = (
+    msg = (
         "Mean Squared Logarithmic Error cannot be used when targets "
         "contain negative values."
     )
-    with pytest.raises(ValueError, match=err_msg):
+    with pytest.raises(ValueError, match=msg):
         mean_squared_log_error([1.0, 2.0, 3.0], [1.0, -2.0, 3.0])
-    err_msg = (
+    msg = (
         "Mean Squared Logarithmic Error cannot be used when targets "
        "contain negative values."
     )
-    with pytest.raises(ValueError, match=err_msg):
+    with pytest.raises(ValueError, match=msg):
         mean_squared_log_error([1.0, -2.0, 3.0], [1.0, 2.0, 3.0])
 
     # Tweedie deviance error
     power = -1.2
     assert_allclose(
         mean_tweedie_deviance([0], [1.0], power=power), 2 / (2 - power), rtol=1e-3
     )
-    with pytest.raises(
-        ValueError, match="can only be used on strictly positive y_pred."
-    ):
+    msg = "can only be used on strictly positive y_pred."
+    with pytest.raises(ValueError, match=msg):
         mean_tweedie_deviance([0.0], [0.0], power=power)
-    assert_almost_equal(mean_tweedie_deviance([0.0], [0.0], power=0), 0.00, 2)
+    with pytest.raises(ValueError, match=msg):
+        d2_tweedie_score([0.0] * 2, [0.0] * 2, power=power)
 
+    assert_almost_equal(mean_tweedie_deviance([0.0], [0.0], power=0), 0.0, 2)
+
+    power = 1.0
     msg = "only be used on non-negative y and strictly positive y_pred."
     with pytest.raises(ValueError, match=msg):
-        mean_tweedie_deviance([0.0], [0.0], power=1.0)
+        mean_tweedie_deviance([0.0], [0.0], power=power)
+    with pytest.raises(ValueError, match=msg):
+        d2_tweedie_score([0.0] * 2, [0.0] * 2, power=power)
 
     power = 1.5
     assert_allclose(mean_tweedie_deviance([0.0], [1.0], power=power), 2 / (2 - power))
     msg = "only be used on non-negative y and strictly positive y_pred."
     with pytest.raises(ValueError, match=msg):
         mean_tweedie_deviance([0.0], [0.0], power=power)
+    with pytest.raises(ValueError, match=msg):
+        d2_tweedie_score([0.0] * 2, [0.0] * 2, power=power)
+
     power = 2.0
     assert_allclose(mean_tweedie_deviance([1.0], [1.0], power=power), 0.00, atol=1e-8)
     msg = "can only be used on strictly positive y and y_pred."
     with pytest.raises(ValueError, match=msg):
         mean_tweedie_deviance([0.0], [0.0], power=power)
+    with pytest.raises(ValueError, match=msg):
+        d2_tweedie_score([0.0] * 2, [0.0] * 2, power=power)
+
     power = 3.0
     assert_allclose(mean_tweedie_deviance([1.0], [1.0], power=power), 0.00, atol=1e-8)
-
     msg = "can only be used on strictly positive y and y_pred."
     with pytest.raises(ValueError, match=msg):
         mean_tweedie_deviance([0.0], [0.0], power=power)
+    with pytest.raises(ValueError, match=msg):
+        d2_tweedie_score([0.0] * 2, [0.0] * 2, power=power)
 
+    power = 0.5
+    with pytest.raises(ValueError, match="is only defined for power<=0 and power>=1"):
+        mean_tweedie_deviance([0.0], [0.0], power=power)
     with pytest.raises(ValueError, match="is only defined for power<=0 and power>=1"):
-        mean_tweedie_deviance([0.0], [0.0], power=0.5)
+        d2_tweedie_score([0.0] * 2, [0.0] * 2, power=power)
 
 
 def test__check_reg_targets():
@@ -319,7 +350,7 @@ def test_regression_custom_weights():
     assert_almost_equal(msle, msle2, decimal=2)
 
 
-@pytest.mark.parametrize("metric", [r2_score])
+@pytest.mark.parametrize("metric", [r2_score, d2_tweedie_score])
 def test_regression_single_sample(metric):
     y_true = [0]
     y_pred = [1]
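The parametrized test above exercises the single-sample contract stated in the new docstring: with fewer than two samples the score is undefined, so an UndefinedMetricWarning is emitted and NaN is returned. A sketch of that contract:

    import warnings

    import numpy as np
    from sklearn.exceptions import UndefinedMetricWarning
    from sklearn.metrics import d2_tweedie_score

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        score = d2_tweedie_score([0], [1])  # a single sample, as in the test

    assert np.isnan(score)
    assert any(issubclass(w.category, UndefinedMetricWarning) for w in caught)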
