8000 ENH Add Poisson, Gamma and Tweedie deviances to regression metrics (#… · scikit-learn/scikit-learn@89da7f7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 89da7f7

Browse files
Christian Lorentzenagramfort
authored andcommitted
ENH Add Poisson, Gamma and Tweedie deviances to regression metrics (#14263)
* ENH add new metric Tweedie deviance * More improvements * A few more fixes * Fix doctest * Review comments * Add common metric tests * Add what's new entry * Fix symmetry check in metrics common tests * Fix test order determinism * Rename metric to mean_tweedie_deviance * Address review comments * Fix test_score_objects.py tests * Add Poisson and Gamma deviances * Fix doc link * Fix typo on what's new * Fix rst rendering
1 parent fb169cd commit 89da7f7

File tree

9 files changed

+479
-50
lines changed

9 files changed

+479
-50
lines changed

doc/modules/classes.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -903,6 +903,9 @@ details.
903903
metrics.mean_squared_log_error
904904
metrics.median_absolute_error
905905
metrics.r2_score
906+
metrics.mean_poisson_deviance
907+
metrics.mean_gamma_deviance
908+
metrics.mean_tweedie_deviance
906909

907910
Multilabel ranking metrics
908911
--------------------------

doc/modules/model_evaluation.rst

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ Scoring Function
9191
'neg_mean_squared_log_error' :func:`metrics.mean_squared_log_error`
9292
'neg_median_absolute_error' :func:`metrics.median_absolute_error`
9393
'r2' :func:`metrics.r2_score`
94+
'neg_mean_poisson_deviance' :func:`metrics.mean_poisson_deviance`
95+
'neg_mean_gamma_deviance' :func:`metrics.mean_gamma_deviance`
9496
============================== ============================================= ==================================
9597

9698

@@ -1957,6 +1959,76 @@ Here is a small example of usage of the :func:`r2_score` function::
19571959
for an example of R² score usage to
19581960
evaluate Lasso and Elastic Net on sparse signals.
19591961

1962+
1963+
.. _mean_tweedie_deviance:
1964+
1965+
Mean Poisson, Gamma, and Tweedie deviances
1966+
------------------------------------------
1967+
The :func:`mean_tweedie_deviance` function computes the `mean Tweedie
1968+
deviance error
1969+
<https://en.wikipedia.org/wiki/Tweedie_distribution#The_Tweedie_deviance>`_
1970+
with power parameter `p`. This is a metric that elicits predicted expectation
1971+
values of regression targets.
1972+
1973+
Following special cases exist,
1974+
1975+
- when `p=0` it is equivalent to :func:`mean_squared_error`.
1976+
- when `p=1` it is equivalent to :func:`mean_poisson_deviance`.
1977+
- when `p=2` it is e A3E2 quivalent to :func:`mean_gamma_deviance`.
1978+
1979+
If :math:`\hat{y}_i` is the predicted value of the :math:`i`-th sample,
1980+
and :math:`y_i` is the corresponding true value, then the mean Tweedie
1981+
deviance error (D) estimated over :math:`n_{\text{samples}}` is defined as
1982+
1983+
.. math::
1984+
1985+
\text{D}(y, \hat{y}) = \frac{1}{n_\text{samples}}
1986+
\sum_{i=0}^{n_\text{samples} - 1}
1987+
\begin{cases}
1988+
(y_i-\hat{y}_i)^2, & \text{for }p=0\text{ (Normal)}\\
1989+
2(y_i \log(y/\hat{y}_i) + \hat{y}_i - y_i), & \text{for }p=1\text{ (Poisson)}\\
1990+
2(\log(\hat{y}_i/y_i) + y_i/\hat{y}_i - 1), & \text{for }p=2\text{ (Gamma)}\\
1991+
2\left(\frac{\max(y_i,0)^{2-p}}{(1-p)(2-p)}-
1992+
\frac{y\,\hat{y}^{1-p}_i}{1-p}+\frac{\hat{y}^{2-p}_i}{2-p}\right),
1993+
& \text{otherwise}
1994+
\end{cases}
1995+
1996+
Tweedie deviance is a homogeneous function of degree ``2-p``.
1997+
Thus, Gamma distribution with `p=2` means that simultaneously scaling `y_true`
1998+
and `y_pred` has no effect on the deviance. For Poisson distribution `p=1`
1999+
the deviance scales linearly, and for Normal distribution (`p=0`),
2000+
quadratically. In general, the higher `p` the less weight is given to extreme
2001+
deviations between true and predicted targets.
2002+
2003+
For instance, let's compare the two predictions 1.0 and 100 that are both
2004+
50% of their corresponding true value.
2005+
2006+
The mean squared error (``p=0``) is very sensitive to the
2007+
prediction difference of the second point,::
2008+
2009+
>>> from sklearn.metrics import mean_tweedie_deviance
2010+
>>> mean_tweedie_deviance([1.0], [1.5], p=0)
2011+
0.25
2012+
>>> mean_tweedie_deviance([100.], [150.], p=0)
2013+
2500.0
2014+
2015+
If we increase ``p`` to 1,::
2016+
2017+
>>> mean_tweedie_deviance([1.0], [1.5], p=1)
2018+
0.18...
2019+
>>> mean_tweedie_deviance([100.], [150.], p=1)
2020+
18.9...
2021+
2022+
the difference in errors decreases. Finally, by setting, ``p=2``::
2023+
2024+
>>> mean_tweedie_deviance([1.0], [1.5], p=2)
2025+
0.14...
2026+
>>> mean_tweedie_deviance([100.], [150.], p=2)
2027+
0.14...
2028+
2029+
we would get identical errors. The deviance when `p=2` is thus only
2030+
sensitive to relative errors.
2031+
19602032
.. _clustering_metrics:
19612033

19622034
Clustering metrics

doc/whats_new/v0.22.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,14 @@ Changelog
131131
- |Feature| Added multiclass support to :func:`metrics.roc_auc_score`.
132132
:issue:`12789` by :user:`Kathy Chen <kathyxchen>`,
133133
:user:`Mohamed Maskani <maskani-moh>`, and :user:`Thomas Fan <thomasjpfan>`.
134+
135+
- |Feature| Add :class:`metrics.mean_tweedie_deviance` measuring the
136+
Tweedie deviance for a power parameter ``p``. Also add mean Poisson deviance
137+
:class:`metrics.mean_poisson_deviance` and mean Gamma deviance
138+
:class:`metrics.mean_gamma_deviance` that are special cases of the Tweedie
139+
deviance for `p=1` and `p=2` respectively.
140+
:pr:`13938` by :user:`Christian Lorentzen <lorentzenchr>` and
141+
`Roman Yurchak`_.
134142

135143
- |Enhancement| The parameter ``beta`` in :func:`metrics.fbeta_score` is
136144
updated to accept the zero and `float('+inf')` value.

sklearn/metrics/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@
6464
from .regression import mean_squared_log_error
6565
from .regression import median_absolute_error
6666
from .regression import r2_score
67+
from .regression import mean_tweedie_deviance
68+
from .regression import mean_poisson_deviance
69+
from .regression import mean_gamma_deviance
6770

6871

6972
from .scorer import check_scoring
@@ -110,6 +113,9 @@
110113
'mean_absolute_error',
111114
'mean_squared_error',
112115
'mean_squared_log_error',
116+
'mean_poisson_deviance',
117+
'mean_gamma_deviance',
118+
'mean_tweedie_deviance',
113119
'median_absolute_error',
114120
'multilabel_confusion_matrix',
115121
'mutual_info_score',

sklearn/metrics/regression.py

Lines changed: 187 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@
1919
# Manoj Kumar <manojkumarsivaraj334@gmail.com>
2020
# Michael Eickenberg <michael.eickenberg@gmail.com>
2121
# Konstantin Shmelkov <konstantin.shmelkov@polytechnique.edu>
22+
# Christian Lorentzen <lorentzen.ch@googlemail.com>
2223
# License: BSD 3 clause
2324

2425

2526
import numpy as np
27+
from scipy.special import xlogy
2628
import warnings
2729

2830
from ..utils.validation import (check_array, check_consistent_length,
@@ -38,11 +40,14 @@
3840
"mean_squared_log_error",
3941
"median_absolute_error",
4042
"r2_score",
41-
"explained_variance_score"
43+
"explained_variance_score",
44+
"mean_tweedie_deviance",
45+
"mean_poisson_deviance",
46+
"mean_gamma_deviance",
4247
]
4348

4449

45-
def _check_reg_targets(y_true, y_pred, multioutput):
50+
def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric"):
4651
"""Check that y_true and y_pred belong to the same regression task
4752
4853
Parameters
@@ -72,11 +77,13 @@ def _check_reg_targets(y_true, y_pred, multioutput):
7277
Custom output weights if ``multioutput`` is array-like or
7378
just the corresponding argument if ``multioutput`` is a
7479
correct keyword.
80+
dtype: str or list, default="numeric"
81+
the dtype argument passed to check_array
7582
7683
"""
7784
check_consistent_length(y_true, y_pred)
78-
y_true = check_array(y_true, ensure_2d=False)
79-
y_pred = check_array(y_pred, ensure_2d=False)
85+
y_true = check_array(y_true, ensure_2d=False, dtype=dtype)
86+
y_pred = check_array(y_pred, ensure_2d=False, dtype=dtype)
8087

8188
if y_true.ndim == 1:
8289
y_true = y_true.reshape((-1, 1))
@@ -609,3 +616,179 @@ def max_error(y_true, y_pred):
609616
if y_type == 'continuous-multioutput':
610617
raise ValueError("Multioutput not supported in max_error")
611618
return np.max(np.abs(y_true - y_pred))
619+
620+
621+
def mean_tweedie_deviance(y_true, y_pred, sample_weight=None, p=0):
622+
"""Mean Tweedie deviance regression loss.
623+
624+
Read more in the :ref:`User Guide <mean_tweedie_deviance>`.
625+
626+
Parameters
627+
----------
628+
y_true : array-like of shape (n_samples,)
629+
Ground truth (correct) target values.
630+
631+
y_pred : array-like of shape (n_samples,)
632+
Estimated target values.
633+
634+
sample_weight : array-like, shape (n_samples,), optional
635+
Sample weights.
636+
637+
p : float, optional
638+
Tweedie power parameter. Either p ≤ 0 or p ≥ 1.
639+
640+
The higher `p` the less weight is given to extreme
641+
deviations between true and predicted targets.
642+
643+
- p < 0: Extreme stable distribution. Requires: y_pred > 0.
644+
- p = 0 : Normal distribution, output corresponds to
645+
mean_squared_error. y_true and y_pred can be any real numbers.
646+
- p = 1 : Poisson distribution. Requires: y_true ≥ 0 and y_pred > 0.
647+
- 1 < p < 2 : Compound Poisson distribution. Requires: y_true ≥ 0
648+
and y_pred > 0.
649+
- p = 2 : Gamma distribution. Requires: y_true > 0 and y_pred > 0.
650+
- p = 3 : Inverse Gaussian distribution. Requires: y_true > 0
651+
and y_pred > 0.
652+
- otherwise : Positive stable distribution. Requires: y_true > 0
653+
and y_pred > 0.
654+
655+
Returns
656+
-------
657+
loss : float
658+
A non-negative floating point value (the best value is 0.0).
659+
660+
Examples
661+
--------
662+
>>> from sklearn.metrics import mean_tweedie_deviance
663+
>>> y_true = [2, 0, 1, 4]
664+
>>> y_pred = [0.5, 0.5, 2., 2.]
665+
>>> mean_tweedie_deviance(y_true, y_pred, p=1)
666+
1.4260...
667+
"""
668+
y_type, y_true, y_pred, _ = _check_reg_targets(
669+
y_true, y_pred, None, dtype=[np.float64, np.float32])
670+
if y_type == 'continuous-multioutput':
671+
raise ValueError("Multioutput not supported in mean_tweedie_deviance")
672+
check_consistent_length(y_true, y_pred, sample_weight)
673+
674+
if sample_weight is not None:
675+
sample_weight = column_or_1d(sample_weight)
676+
sample_weight = sample_weight[:, np.newaxis]
677+
678+
message = ("Mean Tweedie deviance error with p={} can only be used on "
679+
.format(p))
680+
if p < 0:
681+
# 'Extreme stable', y_true any realy number, y_pred > 0
682+
if (y_pred <= 0).any():
683+
raise ValueError(message + "strictly positive y_pred.")
684+
dev = 2 * (np.power(np.maximum(y_true, 0), 2-p)/((1-p) * (2-p)) -
685+
y_true * np.power(y_pred, 1-p)/(1-p) +
686+
np.power(y_pred, 2-p)/(2-p))
687+
elif p == 0:
688+
# Normal distribution, y_true and y_pred any real number
689+
dev = (y_true - y_pred)**2
690+
elif p < 1:
691+
raise ValueError("Tweedie deviance is only defined for p<=0 and "
692+
"p>=1.")
693+
elif p == 1:
694+
# Poisson distribution, y_true >= 0, y_pred > 0
695+
if (y_true < 0).any() or (y_pred <= 0).any():
696+
raise ValueError(message + "non-negative y_true and strictly "
697+
"positive y_pred.")
698+
dev = 2 * (xlogy(y_true, y_true/y_pred) - y_true + y_pred)
699+
elif p == 2:
700+
# Gamma distribution, y_true and y_pred > 0
701+
if (y_true <= 0).any() or (y_pred <= 0).any():
702+
raise ValueError(message + "strictly positive y_true and y_pred.")
703+
dev = 2 * (np.log(y_pred/y_true) + y_true/y_pred - 1)
704+
else:
705+
if p < 2:
706+
# 1 < p < 2 is Compound Poisson, y_true >= 0, y_pred > 0
707+
if (y_true < 0).any() or (y_pred <= 0).any():
708+
raise ValueError(message + "non-negative y_true and strictly "
709+
"positive y_pred.")
710+
else:
711+
if (y_true <= 0).any() or (y_pred <= 0).any():
712+
raise ValueError(message + "strictly positive y_true and "
713+
"y_pred.")
714+
715+
dev = 2 * (np.power(y_true, 2-p)/((1-p) * (2-p)) -
716+
y_true * np.power(y_pred, 1-p)/(1-p) +
717+
np.power(y_pred, 2-p)/(2-p))
718+
719+
return np.average(dev, weights=sample_weight)
720+
721+
722+
def mean_poisson_deviance(y_true, y_pred, sample_weight=None):
723+
"""Mean Poisson deviance regression loss.
724+
725+
Poisson deviance is equivalent to the Tweedie deviance with
726+
the power parameter `p=1`.
727+
728+
Read more in the :ref:`User Guide <mean_tweedie_deviance>`.
729+
730+
Parameters
731+
----------
732+
y_true : array-like of shape (n_samples,)
733+
Ground truth (correct) target values. Requires y_true ≥ 0.
734+
735+
y_pred : array-like of shape (n_samples,)
736+
Estimated target values. Requires y_pred > 0.
737+
738+
sample_weight : array-like, shape (n_samples,), optional
739+
Sample weights.
740+
741+
Returns
742+
-------
743+
loss : float
744+
A non-negative floating point value (the best value is 0.0).
745+
746+
Examples
747+
--------
748+
>>> from sklearn.metrics import mean_poisson_deviance
749+
>>> y_true = [2, 0, 1, 4]
750+
>>> y_pred = [0.5, 0.5, 2., 2.]
751+
>>> mean_poisson_deviance(y_true, y_pred)
752+
1.4260...
753+
"""
754+
return mean_tweedie_deviance(
755+
y_true, y_pred, sample_weight=sample_weight, p=1
756+
)
757+
758+
759+
def mean_gamma_deviance(y_true, y_pred, sample_weight=None):
760+
"""Mean Gamma deviance regression loss.
761+
762+
Gamma deviance is equivalent to the Tweedie deviance with
763+
the power parameter `p=2`. It is invariant to scaling of
764+
the target variable, and mesures relative errors.
765+
766+
Read more in the :ref:`User Guide <mean_tweedie_deviance>`.
767+
768+
Parameters
769+
----------
770+
y_true : array-like of shape (n_samples,)
771+
Ground truth (correct) target values. Requires y_true > 0.
772+
773+
y_pred : array-like of shape (n_samples,)
774+
Estimated target values. Requires y_pred > 0.
775+
776+
sample_weight : array-like, shape (n_samples,), optional
777+
Sample weights.
778+
779+
Returns
780+
-------
781+
loss : float
782+
A non-negative floating point value (the best value is 0.0).
783+
784+
Examples
785+
--------
786+
>>> from sklearn.metrics import mean_gamma_deviance
787+
>>> y_true = [2, 0.5, 1, 4]
788+
>>> y_pred = [0.5, 0.5, 2., 2.]
789+
>>> mean_gamma_deviance(y_true, y_pred)
790+
1.0568...
791+
"""
792+
return mean_tweedie_deviance(
793+
y_true, y_pred, sample_weight=sample_weight, p=2
794+
)

sklearn/metrics/scorer.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
import numpy as np
2525

2626
from . import (r2_score, median_absolute_error, max_error, mean_absolute_error,
27-
mean_squared_error, mean_squared_log_error, accuracy_score,
27+
mean_squared_error, mean_squared_log_error,
28+
mean_tweedie_deviance, accuracy_score,
2829
f1_score, roc_auc_score, average_precision_score,
2930
precision_score, recall_score, log_loss,
3031
balanced_accuracy_score, explained_variance_score,
@@ -492,9 +493,15 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
492493
greater_is_better=False)
493494
neg_mean_absolute_error_scorer = make_scorer(mean_absolute_error,
494495
greater_is_better=False)
495-
496496
neg_median_absolute_error_scorer = make_scorer(median_absolute_error,
497497
greater_is_better=False)
498+
neg_mean_poisson_deviance_scorer = make_scorer(
499+
mean_tweedie_deviance, p=1., greater_is_better=False
500+
)
501+
502+
neg_mean_gamma_deviance_scorer = make_scorer(
503+
mean_tweedie_deviance, p=2., greater_is_better=False
504+
)
498505

499506
# Standard Classification Scores
500507
accuracy_scorer = make_scorer(accuracy_score)
@@ -542,6 +549,8 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
542549
neg_mean_absolute_error=neg_mean_absolute_error_scorer,
543550
neg_mean_squared_error=neg_mean_squared_error_scorer,
544551
neg_mean_squared_log_error=neg_mean_squared_log_error_scorer,
552+
neg_mean_poisson_deviance=neg_mean_poisson_deviance_scorer,
553+
neg_mean_gamma_deviance=neg_mean_gamma_deviance_scorer,
545554
accuracy=accuracy_scorer, roc_auc=roc_auc_scorer,
546555
roc_auc_ovr=roc_auc_ovr_scorer,
547556
roc_auc_ovo=roc_auc_ovo_scorer,

0 commit comments

Comments
 (0)
0