[WIP] Add Generalized Linear Models (#9405) · scikit-learn/scikit-learn@6f4d67c · GitHub

Commit 6f4d67c

Author: Christian Lorentzen
[WIP] Add Generalized Linear Models (#9405)
* added L2 penalty
* api change: alpha, l1_ratio, P1, P2, warm_start, check_input, copy_X
* added entry in user guide
* improved docstrings
* helper function _irls_step
1 parent f3da424 commit 6f4d67c

File tree

3 files changed: +721 additions, -208 deletions

doc/modules/linear_model.rst

Lines changed: 112 additions & 0 deletions
@@ -810,6 +810,118 @@ loss.
.. [7] Aaron Defazio, Francis Bach, Simon Lacoste-Julien: `SAGA: A Fast Incremental Gradient Method With Support for Non-Strongly Convex Composite Objectives. <https://arxiv.org/abs/1407.0202>`_

.. _Generalized_linear_regression:

Generalized linear regression
=============================

:class:`GeneralizedLinearRegressor` generalizes the :ref:`elastic_net` in two
ways [1]_. First, the predicted values :math:`\hat{y}` are linked to a linear
combination of the input variables :math:`X` via an inverse link function
:math:`h` as

.. math:: \hat{y}(w, x) = h(xw) = h(w_0 + w_1 x_1 + ... + w_p x_p).

Secondly, the squared loss function is replaced by the deviance :math:`D` of an
exponential dispersion model (EDM) [2]_. The objective function being minimized
becomes

.. math:: \frac{1}{2s} D(y, \hat{y}) + \alpha \rho \|P_1 w\|_1
          + \frac{\alpha (1 - \rho)}{2} w^T P_2 w

with sample weights :math:`s`.
:math:`P_1` can be used to exclude some of the coefficients from the L1
penalty, while :math:`P_2` (which must be positive semi-definite) allows for a
more versatile L2 penalty.
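
As a minimal sketch of how these penalty terms might be configured (``P1`` and
``P2`` are part of the WIP API added in this commit, so the exact signature may
still change; the values below are purely illustrative):

>>> import numpy as np
>>> from sklearn.linear_model import GeneralizedLinearRegressor
>>> P1 = np.array([0., 1., 1.])  # zero entries exclude coefficients from the L1 penalty
>>> P2 = np.eye(3)               # positive semi-definite matrix for the L2 penalty
>>> reg = GeneralizedLinearRegressor(alpha=0.5, l1_ratio=0.5, P1=P1, P2=P2)

With ``P2`` set to the identity matrix, the L2 term reduces to a plain
ridge-like penalty.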

Use cases where a loss different from the squared loss might be appropriate
include the following (see the sketch after this list):

* If the target values :math:`y` are counts (integer valued) or frequencies,
  you might try a Poisson deviance.

* If the target values are positive valued and skewed, you might try a Gamma
  deviance.

* If the target values seem to be heavy tailed, you might try an Inverse
  Gaussian deviance (or even a higher variance power of the Tweedie family).
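
A sketch of selecting these deviances via the ``family`` parameter, using the
family strings from the table in the next section (the API is still WIP):

>>> from sklearn.linear_model import GeneralizedLinearRegressor
>>> count_reg = GeneralizedLinearRegressor(family='poisson', link='log')      # counts/frequencies
>>> skewed_reg = GeneralizedLinearRegressor(family='gamma', link='log')       # positive, skewed
>>> heavy_reg = GeneralizedLinearRegressor(family='inverse.gaussian', link='log')  # heavy tails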

Since the linear predictor :math:`Xw` can be negative, while the Poisson,
Gamma and Inverse Gaussian distributions do not support negative values, it is
convenient to apply a link function different from the identity link
:math:`h(x)=x` that guarantees non-negativity, e.g. the log link with
:math:`h(Xw)=\exp(Xw)`.

Note that the feature matrix ``X`` should be standardized before fitting (see
the pipeline sketch after the example below). This ensures that the penalty
treats features equally.

>>> from sklearn import linear_model
>>> reg = linear_model.GeneralizedLinearRegressor(alpha=0.5, family='poisson', link='log')
>>> reg.fit([[0, 0], [0, 1], [2, 2]], [0, 1, 2])
>>> reg.coef_
array([ 0.24630255, 0.43373521])
>>> reg.intercept_
-0.76383575123143277
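
A minimal sketch of the standardization advice above, assuming the WIP
estimator composes with pipelines like other scikit-learn regressors:

>>> from sklearn.pipeline import make_pipeline
>>> from sklearn.preprocessing import StandardScaler
>>> model = make_pipeline(StandardScaler(),
...                       linear_model.GeneralizedLinearRegressor(alpha=0.5, family='poisson', link='log'))
>>> model.fit([[0, 0], [0, 1], [2, 2]], [0, 1, 2])  # X is standardized, then the GLM is fit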

Mathematical formulation
------------------------

In the unpenalized case, the assumptions are the following:

* The target values :math:`y_i` are realizations of random variables
  :math:`Y_i \overset{i.i.d.}{\sim} \mathrm{EDM}(\mu_i, \frac{\phi}{s_i})`
  with expectation :math:`\mu_i=\mathrm{E}[Y_i]`, dispersion parameter
  :math:`\phi` and sample weights :math:`s_i`.
* The aim is to predict the expectation :math:`\mu_i` with
  :math:`\hat{y}_i = h(\eta_i)`, linear predictor
  :math:`\eta_i=(Xw)_i` and inverse link function :math:`h(\eta)`.

Note that the first assumption implies
:math:`\mathrm{Var}[Y_i]=\frac{\phi}{s_i} v(\mu_i)` with unit variance
function :math:`v(\mu)`. Specifying a particular distribution of an EDM is the
same as specifying a unit variance function (they are one-to-one).

Including penalties helps to avoid overfitting or, in the case of an L1
penalty, to obtain sparse solutions. But there are also other motivations to
include them, e.g. accounting for the dependence structure of :math:`y`.

The objective function, which is independent of :math:`\phi`, is minimized with
respect to the coefficients :math:`w`.

The deviance is defined by

.. math:: D(y, \mu) = -2\phi \cdot
          \left(\mathrm{loglike}(y, \mu, \frac{\phi}{s})
          - \mathrm{loglike}(y, y, \frac{\phi}{s})\right)

The following table lists some EDM distributions and their unit variance
functions:

=====================================  =================================
Distribution                           Variance Function :math:`v(\mu)`
=====================================  =================================
Normal ("normal")                      :math:`1`
Poisson ("poisson")                    :math:`\mu`
Gamma ("gamma")                        :math:`\mu^2`
Inverse Gaussian ("inverse.gaussian")  :math:`\mu^3`
=====================================  =================================
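
For example, plugging the Poisson log-likelihood into this definition (with
the convention :math:`y_i \log y_i = 0` for :math:`y_i = 0`) gives the
familiar Poisson deviance:

.. math:: D(y, \mu) = 2 \sum_i s_i \left( y_i \log\frac{y_i}{\mu_i}
          - y_i + \mu_i \right)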

Two remarks:

* The deviances for at least the Normal, Poisson and Gamma distributions are
  strictly consistent scoring functions for the mean :math:`\mu`, see Eq.
  (19)-(20) in [3]_.

* If you want to model a frequency, i.e. counts per exposure (time, volume, ...),
  you can do so by using a Poisson distribution and passing
  :math:`y=\frac{\mathrm{counts}}{\mathrm{exposure}}` as target values together
  with :math:`s=\mathrm{exposure}` as sample weights, as sketched below.
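
A minimal sketch of this frequency setup, assuming ``fit`` follows the usual
scikit-learn convention of accepting a ``sample_weight`` argument (the data
below is purely illustrative):

>>> import numpy as np
>>> from sklearn.linear_model import GeneralizedLinearRegressor
>>> X = np.array([[0.], [1.], [2.], [3.]])     # illustrative feature
>>> counts = np.array([0, 1, 4, 2])            # illustrative event counts
>>> exposure = np.array([1.0, 2.0, 2.5, 1.0])  # illustrative exposures
>>> reg = GeneralizedLinearRegressor(family='poisson', link='log')
>>> reg.fit(X, counts / exposure, sample_weight=exposure)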

.. topic:: References:

    .. [1] McCullagh, Peter; Nelder, John (1989). Generalized Linear Models,
       Second Edition. Boca Raton: Chapman and Hall/CRC. ISBN 0-412-31760-5.

    .. [2] Jørgensen, B. (1992). The Theory of Exponential Dispersion Models
       and Analysis of Deviance. Monografias de matemática, no. 51. See also
       `Exponential dispersion model. <https://en.wikipedia.org/wiki/Exponential_dispersion_model>`_

    .. [3] Gneiting, T. (2010). `Making and Evaluating Point Forecasts.
       <https://arxiv.org/pdf/0912.0902.pdf>`_

Stochastic Gradient Descent - SGD
=================================
