ENH Minimal Generalized linear models implementation (L2 + lbfgs) (#14300)
scikit-learn/scikit-learn · commit 69ea066 · 1 parent 60b8fb2

Co-authored-by: Christian Lorentzen <lorentzen.ch@googlemail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

21 files changed: +2925 −57 lines

doc/modules/classes.rst
Lines changed: 15 additions & 0 deletions

@@ -837,6 +837,21 @@ Any estimator using the Huber loss would also be robust to outliers, e.g.

   linear_model.RANSACRegressor
   linear_model.TheilSenRegressor

Generalized linear models (GLM) for regression
----------------------------------------------

These models allow for response variables to have error distributions other
than a normal distribution:

.. autosummary::
   :toctree: generated/
   :template: class.rst

   linear_model.PoissonRegressor
   linear_model.TweedieRegressor
   linear_model.GammaRegressor

Miscellaneous
-------------

doc/modules/linear_model.rst
Lines changed: 146 additions & 3 deletions
@@ -556,13 +556,13 @@ orthogonal matching pursuit can approximate the optimum solution vector with a
 fixed number of non-zero elements:

 .. math::
-    \underset{\gamma}{\operatorname{arg\,min\,}} ||y - X\gamma||_2^2 \text{ subject to } ||\gamma||_0 \leq n_{\text{nonzero\_coefs}}
+    \underset{w}{\operatorname{arg\,min\,}} ||y - Xw||_2^2 \text{ subject to } ||w||_0 \leq n_{\text{nonzero\_coefs}}

 Alternatively, orthogonal matching pursuit can target a specific error instead
 of a specific number of non-zero coefficients. This can be expressed as:

 .. math::
-    \underset{\gamma}{\operatorname{arg\,min\,}} ||\gamma||_0 \text{ subject to } ||y-X\gamma||_2^2 \leq \text{tol}
+    \underset{w}{\operatorname{arg\,min\,}} ||w||_0 \text{ subject to } ||y-Xw||_2^2 \leq \text{tol}

 OMP is based on a greedy algorithm that includes at each step the atom most
@@ -906,7 +906,7 @@ with 'log' loss, which might be even faster but requires more tuning.
 It is possible to obtain the p-values and confidence intervals for
 coefficients in cases of regression without penalization. The `statsmodels
 package <https://pypi.org/project/statsmodels/>`_ natively supports this.
 Within sklearn, one could use bootstrapping instead as well.

 :class:`LogisticRegressionCV` implements Logistic Regression with built-in
@@ -928,6 +928,149 @@ to warm-starting (see :term:`Glossary <warm_start>`).

 .. [9] `"Performance Evaluation of Lbfgs vs other solvers"
    <http://www.fuzihao.org/blog/2016/01/16/Comparison-of-Gradient-Descent-Stochastic-Gradient-Descent-and-L-BFGS/>`_

.. _Generalized_linear_regression:

Generalized Linear Regression
=============================

Generalized Linear Models (GLM) extend linear models in two ways
[10]_. First, the predicted values :math:`\hat{y}` are linked to a linear
combination of the input variables :math:`X` via an inverse link function
:math:`h` as

.. math:: \hat{y}(w, X) = h(Xw).

Second, the squared loss function is replaced by the unit deviance
:math:`d` of a distribution in the exponential family (or more precisely, a
reproductive exponential dispersion model (EDM) [11]_).

The minimization problem becomes:

.. math:: \min_{w} \frac{1}{2 n_{\text{samples}}} \sum_i d(y_i, \hat{y}_i) + \frac{\alpha}{2} ||w||_2^2,

where :math:`\alpha` is the L2 regularization penalty. When sample weights are
provided, the average becomes a weighted average.

The following table lists some specific EDMs and their unit deviance (all of
these are instances of the Tweedie family):
================= =============================== ====================================================
Distribution      Target Domain                   Unit Deviance :math:`d(y, \hat{y})`
================= =============================== ====================================================
Normal            :math:`y \in (-\infty, \infty)` :math:`(y-\hat{y})^2`
Poisson           :math:`y \in [0, \infty)`       :math:`2(y\log\frac{y}{\hat{y}}-y+\hat{y})`
Gamma             :math:`y \in (0, \infty)`       :math:`2(\log\frac{\hat{y}}{y}+\frac{y}{\hat{y}}-1)`
Inverse Gaussian  :math:`y \in (0, \infty)`       :math:`\frac{(y-\hat{y})^2}{y\hat{y}^2}`
================= =============================== ====================================================
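The unit deviances in the table above can be checked numerically. As a sketch (assuming a scikit-learn version that provides :func:`metrics.mean_tweedie_deviance` with the ``power`` parameter, i.e. >= 0.23), the Tweedie deviance with ``power`` 0, 1 and 2 reproduces the Normal, Poisson and Gamma rows:

```python
import numpy as np
from sklearn.metrics import mean_tweedie_deviance

# Small illustrative vectors; any positive values would do.
y = np.array([1.0, 2.0, 3.0])
y_hat = np.array([1.5, 2.5, 2.0])

# power=0 (Normal): the deviance reduces to the squared error.
d_normal = mean_tweedie_deviance(y, y_hat, power=0)
assert np.allclose(d_normal, np.mean((y - y_hat) ** 2))

# power=1 (Poisson): 2 * (y * log(y / y_hat) - y + y_hat).
d_poisson = mean_tweedie_deviance(y, y_hat, power=1)
assert np.allclose(d_poisson, np.mean(2 * (y * np.log(y / y_hat) - y + y_hat)))

# power=2 (Gamma): 2 * (log(y_hat / y) + y / y_hat - 1).
d_gamma = mean_tweedie_deviance(y, y_hat, power=2)
assert np.allclose(d_gamma, np.mean(2 * (np.log(y_hat / y) + y / y_hat - 1)))
```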
The probability density functions (PDF) of these distributions are illustrated
in the following figure:

.. figure:: ./glm_data/poisson_gamma_tweedie_distributions.png
   :align: center
   :scale: 100%

   PDF of a random variable Y following Poisson, Tweedie (power=1.5) and Gamma
   distributions with different mean values (:math:`\mu`). Observe the point
   mass at :math:`Y=0` for the Poisson and the Tweedie (power=1.5)
   distributions, but not for the Gamma distribution, which has a strictly
   positive target domain.
The choice of the distribution depends on the problem at hand:

* If the target values :math:`y` are counts (non-negative integer valued) or
  relative frequencies (non-negative), you might use a Poisson deviance
  with log-link.
* If the target values are positive valued and skewed, you might try a
  Gamma deviance with log-link.
* If the target values seem to be heavier tailed than a Gamma distribution,
  you might try an Inverse Gaussian deviance (or even higher variance powers
  of the Tweedie family).

Examples of use cases include:

* Agriculture / weather modeling: number of rain events per year (Poisson),
  amount of rainfall per event (Gamma), total rainfall per year (Tweedie /
  Compound Poisson Gamma).
* Risk modeling / insurance policy pricing: number of claim events per
  policyholder per year (Poisson), cost per event (Gamma), total cost per
  policyholder per year (Tweedie / Compound Poisson Gamma).
* Predictive maintenance: number of production interruption events per year
  (Poisson), duration of interruption (Gamma), total interruption time per
  year (Tweedie / Compound Poisson Gamma).
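As a sketch of the count-data case above, a Poisson GLM with log link fit on a small synthetic frequency dataset (the data-generating process here is hypothetical, chosen only for illustration; assumes scikit-learn >= 0.23):

```python
import numpy as np
from sklearn.linear_model import PoissonRegressor

rng = np.random.RandomState(0)
X = rng.uniform(size=(100, 2))
# Hypothetical count target drawn from a log-linear Poisson rate.
y = rng.poisson(lam=np.exp(1.0 + 2.0 * X[:, 0]))

reg = PoissonRegressor(alpha=1e-3).fit(X, y)
# With the log link, predictions are exp(Xw + intercept), hence always positive.
assert (reg.predict(X) > 0).all()
```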
1004+
.. topic:: References:

    .. [10] McCullagh, Peter; Nelder, John (1989). Generalized Linear Models,
        Second Edition. Boca Raton: Chapman and Hall/CRC. ISBN 0-412-31760-5.

    .. [11] Jørgensen, B. (1992). The theory of exponential dispersion models
        and analysis of deviance. Monografias de matemática, no. 51. See also
        `Exponential dispersion model.
        <https://en.wikipedia.org/wiki/Exponential_dispersion_model>`_
1014+
Usage
-----

:class:`TweedieRegressor` implements a generalized linear model for the
Tweedie distribution, which can model any of the above-mentioned
distributions through the appropriate ``power`` parameter. In particular:

- ``power = 0``: Normal distribution. Specific estimators such as
  :class:`Ridge` and :class:`ElasticNet` are generally more appropriate in
  this case.
- ``power = 1``: Poisson distribution. :class:`PoissonRegressor` is exposed
  for convenience. However, it is strictly equivalent to
  `TweedieRegressor(power=1, link='log')`.
- ``power = 2``: Gamma distribution. :class:`GammaRegressor` is exposed for
  convenience. However, it is strictly equivalent to
  `TweedieRegressor(power=2, link='log')`.
- ``power = 3``: Inverse Gaussian distribution.

The link function is determined by the `link` parameter.
Usage example::

    >>> from sklearn.linear_model import TweedieRegressor
    >>> reg = TweedieRegressor(power=1, alpha=0.5, link='log')
    >>> reg.fit([[0, 0], [0, 1], [2, 2]], [0, 1, 2])
    TweedieRegressor(alpha=0.5, link='log', power=1)
    >>> reg.coef_
    array([0.2463..., 0.4337...])
    >>> reg.intercept_
    -0.7638...
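The equivalence noted in the ``power = 1`` bullet above can be verified directly: fitting :class:`PoissonRegressor` and ``TweedieRegressor(power=1, link='log')`` on the same data should produce the same coefficients up to solver tolerance (a sketch, assuming scikit-learn >= 0.23):

```python
import numpy as np
from sklearn.linear_model import PoissonRegressor, TweedieRegressor

X, y = [[0, 0], [0, 1], [2, 2]], [0, 1, 2]

# Both estimators minimize the same penalized Poisson deviance with lbfgs.
poisson = PoissonRegressor(alpha=0.5).fit(X, y)
tweedie = TweedieRegressor(power=1, alpha=0.5, link='log').fit(X, y)

assert np.allclose(poisson.coef_, tweedie.coef_, atol=1e-4)
assert np.isclose(poisson.intercept_, tweedie.intercept_, atol=1e-4)
```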
.. topic:: Examples:

  * :ref:`sphx_glr_auto_examples_linear_model_plot_poisson_regression_non_normal_loss.py`
  * :ref:`sphx_glr_auto_examples_linear_model_plot_tweedie_regression_insurance_claims.py`

Practical considerations
------------------------
1054+
The feature matrix `X` should be standardized before fitting. This ensures
that the penalty treats features equally.

Since the linear predictor :math:`Xw` can be negative, while the Poisson,
Gamma and Inverse Gaussian distributions do not support negative values, it
is necessary to apply an inverse link function that guarantees non-negative
predicted values. For example, with `link='log'` the inverse link function
is :math:`h(Xw)=\exp(Xw)`.
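A minimal sketch of the standardization advice above, using a pipeline so that the scaler is fit only on the training data (feature scales here are hypothetical; assumes scikit-learn >= 0.23 for the GLM estimator):

```python
import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import PoissonRegressor

rng = np.random.RandomState(42)
# Second feature on a much larger scale than the first.
X = rng.uniform(size=(50, 2)) * np.array([1.0, 1000.0])
y = rng.poisson(lam=3.0, size=50)

# StandardScaler brings both features to zero mean and unit variance,
# so the L2 penalty treats them equally.
model = make_pipeline(StandardScaler(), PoissonRegressor(alpha=1.0)).fit(X, y)
assert (model.predict(X) > 0).all()
```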
1063+
If you want to model a relative frequency, i.e. counts per exposure (time,
volume, ...), you can do so by using a Poisson distribution and passing
:math:`y=\frac{\mathrm{counts}}{\mathrm{exposure}}` as target values
together with :math:`\mathrm{exposure}` as sample weights. For a concrete
example see
:ref:`sphx_glr_auto_examples_linear_model_plot_tweedie_regression_insurance_claims.py`.

When performing cross-validation for the `power` parameter of
`TweedieRegressor`, it is advisable to specify an explicit `scoring` function,
because the default scorer :meth:`TweedieRegressor.score` is itself a function
of `power`.
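The counts-per-exposure recipe above can be sketched as follows (the data and exposure values are hypothetical; assumes scikit-learn >= 0.23):

```python
import numpy as np
from sklearn.linear_model import PoissonRegressor

rng = np.random.RandomState(0)
X = rng.uniform(size=(200, 3))
# Hypothetical exposure, e.g. policy duration in years.
exposure = rng.uniform(low=0.5, high=2.0, size=200)
# Counts scale with both the exposure and a log-linear rate.
counts = rng.poisson(lam=exposure * np.exp(X[:, 0]))

# Fit on the frequency counts / exposure, weighting each sample by its exposure.
freq_model = PoissonRegressor(alpha=1e-2).fit(
    X, counts / exposure, sample_weight=exposure
)
assert (freq_model.predict(X) > 0).all()
```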

Stochastic Gradient Descent - SGD
=================================

doc/whats_new/v0.23.rst
Lines changed: 7 additions & 0 deletions

@@ -221,6 +221,13 @@ Changelog

:mod:`sklearn.linear_model`
...........................

- |MajorFeature| Added generalized linear models (GLM) with non-normal error
  distributions, including :class:`linear_model.PoissonRegressor`,
  :class:`linear_model.GammaRegressor` and :class:`linear_model.TweedieRegressor`,
  which use Poisson, Gamma and Tweedie distributions respectively.
  :pr:`14300` by :user:`Christian Lorentzen <lorentzenchr>`, `Roman Yurchak`_,
  and `Olivier Grisel`_.

- |Feature| Support of `sample_weight` in :class:`linear_model.ElasticNet` and
  :class:`linear_model.Lasso` for dense feature matrix `X`.
  :pr:`15436` by :user:`Christian Lorentzen <lorentzenchr>`.

0 commit comments
