[WIP] Add Generalized Linear Models (#9405) · scikit-learn/scikit-learn@eeebb3d · GitHub
Commit eeebb3d

Author: Christian Lorentzen (committed)

[WIP] Add Generalized Linear Models (#9405)

* fix some bugs in user guide linear_model.rst
* fix some pep8 issues in test_glm.py

1 parent 6f4d67c, commit eeebb3d

2 files changed: +18 -13 lines changed

doc/modules/linear_model.rst (+14 -10)
@@ -816,14 +816,14 @@ Generalized linear regression
 =============================
 
 :class:`GeneralizedLinearRegressor` generalizes the :ref:`elastic_net` in two
-ways [1]_. First, the predicted values :math:`\hat{y}` are linked to a linear
+ways [8]_. First, the predicted values :math:`\hat{y}` are linked to a linear
 combination of the input variables :math:`X` via an inverse link function
 :math:`h` as
 
 .. math:: \hat{y}(w, x) = h(xw) = h(w_0 + w_1 x_1 + ... + w_p x_p).
 
 Secondly, the squared loss function is replaced by the deviance :math:`D` of an
-exponential dispersion model (EDM) [2]_. The objective function beeing minimized
+exponential dispersion model (EDM) [9]_. The objective function beeing minimized
 becomes
 
 .. math:: \frac{1}{2s}D(y, \hat{y}) + \alpha \rho ||P_1w||_1
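The penalized deviance objective shown in this hunk can be sketched numerically. The helper below is an illustrative sketch, not the PR's implementation: it assumes a Poisson deviance with a log link, and treats the penalty weights :math:`P_1` as all-ones so the shown term reduces to `alpha * rho * ||w||_1` (the hunk is cut off before any further penalty terms).

```python
import numpy as np

def poisson_deviance(y, y_hat):
    # Poisson deviance D(y, y_hat) = 2 * sum(y*log(y/y_hat) - (y - y_hat));
    # the y*log(y/y_hat) term is defined as 0 where y == 0.
    y = np.asarray(y, dtype=float)
    y_hat = np.asarray(y_hat, dtype=float)
    with np.errstate(divide="ignore", invalid="ignore"):
        term = np.where(y > 0, y * np.log(y / y_hat), 0.0)
    return 2.0 * np.sum(term - (y - y_hat))

def objective(w, X, y, alpha, rho, s=1.0):
    # 1/(2s) * D(y, h(Xw)) + alpha * rho * ||w||_1, with log link h(z) = exp(z)
    # (P_1 assumed to be all-ones here for simplicity).
    y_hat = np.exp(X @ w)
    return poisson_deviance(y, y_hat) / (2.0 * s) + alpha * rho * np.sum(np.abs(w))
```

A perfect fit gives zero deviance, so the objective at an exact solution reduces to the penalty term alone.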
@@ -850,12 +850,16 @@ it is convenient to apply a link function different from the identity link
 :math:`h(Xw)=\exp(Xw)`.
 
 Note that the feature matrix `X` should be standardized before fitting. This
-ensures that the penalty treats features equally.
+ensures that the penalty treats features equally. The estimator can be used as
+follows::
 
-    >>> from sklearn import linear_model
-    >>> reg = linear_model.GeneralizedLinearRegressor(alpha=0.5, l1_ratio=0)
-    >>> reg = linear_model.GeneralizedLinearRegressor(alpha=0.5, family='poisson', link='log')
+    >>> from sklearn.linear_model import GeneralizedLinearRegressor
+    >>> reg = GeneralizedLinearRegressor(alpha=0.5, family='poisson', link='log')
     >>> reg.fit([[0, 0], [0, 1], [2, 2]], [0, 1, 2])
+    GeneralizedLinearRegressor(alpha=0.5, copy_X=True, family='poisson',
+               fit_dispersion='chisqr', fit_intercept=True, l1_ratio=0,
+               link='log', max_iter=100, solver='irls', start_params=None,
+               tol=0.0001, verbose=0, warm_start=False)
     >>> reg.coef_
     array([ 0.24630255,  0.43373521])
     >>> reg.intercept_
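For readers following this WIP: in released scikit-learn (0.23 and later), the estimator from this PR landed in a simplified form as `PoissonRegressor`, which fixes the log link and supports an L2 penalty only. A rough modern equivalent of the doctest above (parameter values chosen for illustration, not taken from the PR):

```python
from sklearn.linear_model import PoissonRegressor

# The log link is built in; `alpha` is the L2 regularization strength.
reg = PoissonRegressor(alpha=0.5)
reg.fit([[0, 0], [0, 1], [2, 2]], [0, 1, 2])
print(reg.coef_)       # one coefficient per feature
print(reg.intercept_)
```

The fitted coefficients differ from the doctest's values because the released estimator uses a different penalty and solver.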
@@ -905,7 +909,7 @@ Two remarks:
 
 * The deviances for at least Normal, Poisson and Gamma distributions are
   strictly consistent scoring functions for the mean :math:`\mu`, see Eq.
-  (19)-(20) in [3]_.
+  (19)-(20) in [10]_.
 
 * If you want to model a frequency, i.e. counts per exposure (time, volume, ...)
   you can do so by a Poisson distribution and passing
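The second remark is truncated by the hunk boundary. One common pattern for such frequency models — an assumption here, following standard GLM practice rather than this hunk's elided text — is to fit counts divided by exposure as the target, with the exposure passed as sample weights; sketched below with the released `PoissonRegressor`:

```python
import numpy as np
from sklearn.linear_model import PoissonRegressor

counts = np.array([0, 3, 10, 8])
exposure = np.array([1.0, 2.0, 5.0, 4.0])   # e.g. observation time per row
X = np.array([[0.0], [1.0], [2.0], [2.0]])

# Model the frequency counts/exposure, weighting each row by its exposure.
freq = counts / exposure
reg = PoissonRegressor(alpha=1e-4)
reg.fit(X, freq, sample_weight=exposure)

# Multiply predicted frequencies by exposure to recover expected counts.
pred_counts = reg.predict(X) * exposure
```

Weighting by exposure makes a row observed for twice as long count twice as much in the fit.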
@@ -915,12 +919,12 @@ Two remarks:
 
 .. topic:: References:
 
-    .. [1] McCullagh, Peter; Nelder, John (1989). Generalized Linear Models, Second Edition. Boca Raton: Chapman and Hall/CRC. ISBN 0-412-31760-5.
+    .. [8] McCullagh, Peter; Nelder, John (1989). Generalized Linear Models, Second Edition. Boca Raton: Chapman and Hall/CRC. ISBN 0-412-31760-5.
 
-    .. [2] Jørgensen, B. (1992). The theory of exponential dispersion models and analysis of deviance. Monografias de matemática, no. 51.
+    .. [9] Jørgensen, B. (1992). The theory of exponential dispersion models and analysis of deviance. Monografias de matemática, no. 51.
        See also `Exponential dispersion model. <https://en.wikipedia.org/wiki/Exponential_dispersion_model>`_
 
-    .. [3] Gneiting, T. (2010). `Making and Evaluating Point Forecasts. <https://arxiv.org/pdf/0912.0902.pdf>`_
+    .. [10] Gneiting, T. (2010). `Making and Evaluating Point Forecasts. <https://arxiv.org/pdf/0912.0902.pdf>`_
 
 Stochastic Gradient Descent - SGD
 =================================

sklearn/linear_model/tests/test_glm.py (+4 -3)
@@ -2,7 +2,7 @@
 
 from sklearn.linear_model.glm import (
     Link,
-    IdentityLink,
+    # IdentityLink,
     LogLink,
     TweedieDistribution,
     NormalDistribution, PoissonDistribution,
@@ -21,8 +21,9 @@ def test_link_properties():
     """
     rng = np.random.RandomState(0)
     x = rng.rand(100)*100
-    from sklearn.linear_model.glm import Link
-    for link in vars()['Link'].__subclasses__():
+    # from sklearn.linear_model.glm import Link
+    # for link in vars()['Link'].__subclasses__():
+    for link in Link.__subclasses__():
         link = link()
         assert_almost_equal(link.link(link.inverse(x)), x, decimal=10)
         assert_almost_equal(link.inverse_derivative(link.link(x)),
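The fix in this hunk replaces the fragile `vars()['Link']` lookup with a direct `Link.__subclasses__()` iteration over the module-level import. A self-contained sketch of that round-trip test pattern, using stand-in `Link` classes rather than the PR's actual `sklearn.linear_model.glm` implementations:

```python
import numpy as np

class Link:
    """Base class: subclasses define a link function g and its inverse h = g^{-1}."""
    def link(self, mu):
        raise NotImplementedError
    def inverse(self, eta):
        raise NotImplementedError

class IdentityLink(Link):
    def link(self, mu):
        return mu
    def inverse(self, eta):
        return eta

class LogLink(Link):
    def link(self, mu):
        return np.log(mu)
    def inverse(self, eta):
        return np.exp(eta)

# Same pattern as the fixed test: iterate all subclasses of the base class
# and check that g(h(x)) recovers x to high precision.
rng = np.random.RandomState(0)
x = rng.rand(100) * 100
for link_cls in Link.__subclasses__():
    link = link_cls()
    np.testing.assert_almost_equal(link.link(link.inverse(x)), x, decimal=10)
```

Registering link implementations via `__subclasses__()` means a newly added link class is picked up by the test automatically, with no per-class boilerplate.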
