ENH add newton-cholesky solver to LogisticRegression (#24767) · andportnoy/scikit-learn@281f4d9

Commit 281f4d9

lorentzenchr authored and andportnoy committed
ENH add newton-cholesky solver to LogisticRegression (scikit-learn#24767)
1 parent 491ead9 commit 281f4d9

5 files changed: +260 −185 lines changed


doc/modules/linear_model.rst

Lines changed: 35 additions & 27 deletions
@@ -954,7 +954,7 @@ Solvers
 -------
 
 The solvers implemented in the class :class:`LogisticRegression`
-are "liblinear", "newton-cg", "lbfgs", "sag" and "saga":
+are "lbfgs", "liblinear", "newton-cg", "newton-cholesky", "sag" and "saga":
 
 The solver "liblinear" uses a coordinate descent (CD) algorithm, and relies
 on the excellent C++ `LIBLINEAR library
@@ -968,7 +968,7 @@ classifiers. For :math:`\ell_1` regularization :func:`sklearn.svm.l1_min_c` allo
 calculate the lower bound for C in order to get a non "null" (all feature
 weights to zero) model.
 
-The "lbfgs", "sag" and "newton-cg" solvers only support :math:`\ell_2`
+The "lbfgs", "newton-cg" and "sag" solvers only support :math:`\ell_2`
 regularization or no regularization, and are found to converge faster for some
 high-dimensional data. Setting `multi_class` to "multinomial" with these solvers
 learns a true multinomial logistic regression model [5]_, which means that its
@@ -989,33 +989,41 @@ Broyden–Fletcher–Goldfarb–Shanno algorithm [8]_, which belongs to
 quasi-Newton methods. The "lbfgs" solver is recommended for use for
 small data-sets but for larger datasets its performance suffers. [9]_
 
+The "newton-cholesky" solver is an exact Newton solver that calculates the hessian
+matrix and solves the resulting linear system. It is a very good choice for
+`n_samples` >> `n_features`, but has a few shortcomings: Only :math:`\ell_2`
+regularization is supported. Furthermore, because the hessian matrix is explicitly
+computed, the memory usage has a quadratic dependency on `n_features` as well as on
+`n_classes`. As a consequence, only the one-vs-rest scheme is implemented for the
+multiclass case.
+
 The following table summarizes the penalties supported by each solver:
 
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
-|                              | **Solvers**                                                              |
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
-| **Penalties**                | **'liblinear'** | **'lbfgs'** | **'newton-cg'** | **'sag'** | **'saga'** |
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
-| Multinomial + L2 penalty     | no              | yes         | yes             | yes       | yes        |
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
-| OVR + L2 penalty             | yes             | yes         | yes             | yes       | yes        |
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
-| Multinomial + L1 penalty     | no              | no          | no              | no        | yes        |
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
-| OVR + L1 penalty             | yes             | no          | no              | no        | yes        |
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
-| Elastic-Net                  | no              | no          | no              | no        | yes        |
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
-| No penalty ('none')          | no              | yes         | yes             | yes       | yes        |
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
-| **Behaviors**                |                                                                          |
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
-| Penalize the intercept (bad) | yes             | no          | no              | no        | no         |
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
-| Faster for large datasets    | no              | no          | no              | yes       | yes        |
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
-| Robust to unscaled datasets  | yes             | yes         | yes             | no        | no         |
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+
+|                              | **Solvers**                                                                                      |
++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+
+| **Penalties**                | **'lbfgs'** | **'liblinear'** | **'newton-cg'** | **'newton-cholesky'** | **'sag'** | **'saga'** |
++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+
+| Multinomial + L2 penalty     | yes         | no              | yes             | no                    | yes       | yes        |
++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+
+| OVR + L2 penalty             | yes         | yes             | yes             | yes                   | yes       | yes        |
++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+
+| Multinomial + L1 penalty     | no          | no              | no              | no                    | no        | yes        |
++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+
+| OVR + L1 penalty             | no          | yes             | no              | no                    | no        | yes        |
++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+
+| Elastic-Net                  | no          | no              | no              | no                    | no        | yes        |
++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+
+| No penalty ('none')          | yes         | no              | yes             | yes                   | yes       | yes        |
++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+
+| **Behaviors**                |                                                                                                  |
++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+
+| Penalize the intercept (bad) | no          | yes             | no              | no                    | no        | no         |
++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+
+| Faster for large datasets    | no          | no              | no              | no                    | yes       | yes        |
++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+
+| Robust to unscaled datasets  | yes         | yes             | yes             | yes                   | no        | no         |
++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+
 
 The "lbfgs" solver is used by default for its robustness. For large datasets
 the "saga" solver is usually faster.

doc/whats_new/v1.2.rst

Lines changed: 3 additions & 2 deletions
@@ -365,15 +365,16 @@ Changelog
 :mod:`sklearn.linear_model`
 ...........................
 
-- |Enhancement| :class:`linear_model.GammaRegressor`,
+- |Enhancement| :class:`linear_model.LogisticRegression`,
+  :class:`linear_model.LogisticRegressionCV`, :class:`linear_model.GammaRegressor`,
   :class:`linear_model.PoissonRegressor` and :class:`linear_model.TweedieRegressor` got
   a new solver `solver="newton-cholesky"`. This is a 2nd order (Newton) optimisation
   routine that uses a Cholesky decomposition of the hessian matrix.
   When `n_samples >> n_features`, the `"newton-cholesky"` solver has been observed to
   converge both faster and to a higher precision solution than the `"lbfgs"` solver on
   problems with one-hot encoded categorical variables with some rare categorical
   levels.
-  :pr:`24637` by :user:`Christian Lorentzen <lorentzenchr>`.
+  :pr:`24637` and :pr:`24767` by :user:`Christian Lorentzen <lorentzenchr>`.
 
 - |Enhancement| :class:`linear_model.GammaRegressor`,
   :class:`linear_model.PoissonRegressor` and :class:`linear_model.TweedieRegressor`
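
The observation about one-hot encoded categorical variables can be checked with a small benchmark. A hedged sketch (assuming scikit-learn >= 1.2; the Zipf-distributed category generator and the timing loop are illustrative choices, not part of the commit):

    import numpy as np
    from time import perf_counter
    from sklearn.linear_model import LogisticRegression
    from sklearn.preprocessing import OneHotEncoder

    rng = np.random.default_rng(0)
    # One categorical column with a long tail of rare levels, then one-hot
    # encoded, so n_samples >> n_features with many rare indicator columns.
    cats = rng.zipf(a=2.0, size=50_000).clip(max=100).reshape(-1, 1)
    X = OneHotEncoder(sparse_output=False).fit_transform(cats)
    y = rng.integers(0, 2, size=X.shape[0])

    for solver in ("lbfgs", "newton-cholesky"):
        tic = perf_counter()
        clf = LogisticRegression(solver=solver).fit(X, y)
        print(f"{solver}: {perf_counter() - tic:.2f}s, n_iter={clf.n_iter_}")

Absolute timings vary by machine; the changelog's claim concerns both fit time and the precision of the converged solution.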

sklearn/linear_model/_glm/glm.py

Lines changed: 17 additions & 5 deletions
@@ -75,7 +75,10 @@ class _GeneralizedLinearRegressor(RegressorMixin, BaseEstimator):
         'newton-cholesky'
             Uses Newton-Raphson steps (in arbitrary precision arithmetic equivalent to
             iterated reweighted least squares) with an inner Cholesky based solver.
-            This solver is suited for n_samples >> n_features.
+            This solver is a good choice for `n_samples` >> `n_features`, especially
+            with one-hot encoded categorical features with rare categories. Be aware
+            that the memory usage of this solver has a quadratic dependency on
+            `n_features` because it explicitly computes the Hessian matrix.
 
         .. versionadded:: 1.2
 
@@ -304,7 +307,7 @@ def fit(self, X, y, sample_weight=None):
             coef = sol.solve(X, y, sample_weight)
             self.n_iter_ = sol.iteration
         else:
-            raise TypeError(f"Invalid solver={self.solver}.")
+            raise ValueError(f"Invalid solver={self.solver}.")
 
         if self.fit_intercept:
             self.intercept_ = coef[-1]
@@ -512,7 +515,10 @@ class PoissonRegressor(_GeneralizedLinearRegressor):
         'newton-cholesky'
             Uses Newton-Raphson steps (in arbitrary precision arithmetic equivalent to
             iterated reweighted least squares) with an inner Cholesky based solver.
-            This solver is suited for n_samples >> n_features.
+            This solver is a good choice for `n_samples` >> `n_features`, especially
+            with one-hot encoded categorical features with rare categories. Be aware
+            that the memory usage of this solver has a quadratic dependency on
+            `n_features` because it explicitly computes the Hessian matrix.
 
         .. versionadded:: 1.2
 
@@ -640,7 +646,10 @@ class GammaRegressor(_GeneralizedLinearRegressor):
         'newton-cholesky'
             Uses Newton-Raphson steps (in arbitrary precision arithmetic equivalent to
             iterated reweighted least squares) with an inner Cholesky based solver.
-            This solver is suited for n_samples >> n_features.
+            This solver is a good choice for `n_samples` >> `n_features`, especially
+            with one-hot encoded categorical features with rare categories. Be aware
+            that the memory usage of this solver has a quadratic dependency on
+            `n_features` because it explicitly computes the Hessian matrix.
 
         .. versionadded:: 1.2
 
@@ -799,7 +808,10 @@ class TweedieRegressor(_GeneralizedLinearRegressor):
         'newton-cholesky'
             Uses Newton-Raphson steps (in arbitrary precision arithmetic equivalent to
             iterated reweighted least squares) with an inner Cholesky based solver.
-            This solver is suited for n_samples >> n_features.
+            This solver is a good choice for `n_samples` >> `n_features`, especially
+            with one-hot encoded categorical features with rare categories. Be aware
+            that the memory usage of this solver has a quadratic dependency on
+            `n_features` because it explicitly computes the Hessian matrix.
 
         .. versionadded:: 1.2
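
The same solver option exists on the GLM side, and the hunk above also changes the fit-time fallback for an unknown solver from TypeError to ValueError. A hedged sketch (assuming scikit-learn >= 1.2; the data generation is illustrative):

    import numpy as np
    from sklearn.linear_model import PoissonRegressor

    rng = np.random.default_rng(0)
    X = rng.normal(size=(10_000, 5))
    # Poisson-distributed counts with a log link, so y is non-negative.
    y = rng.poisson(lam=np.exp(X @ np.full(5, 0.1)))

    reg = PoissonRegressor(solver="newton-cholesky").fit(X, y)
    print(reg.n_iter_, reg.score(X, y))

    # An unknown solver now surfaces as a ValueError. In released versions,
    # scikit-learn's parameter validation typically rejects it first, with an
    # InvalidParameterError that is itself a ValueError subclass.
    try:
        PoissonRegressor(solver="bogus").fit(X, y)
    except ValueError as exc:
        print(type(exc).__name__, exc)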
