FEA add (single) Cholesky Newton solver to GLMs by lorentzenchr · Pull Request #24637 · scikit-learn/scikit-learn · GitHub
FEA add (single) Cholesky Newton solver to GLMs #24637


Merged 115 commits on Oct 24, 2022

Commits
d4a8b4e
FEA add NewtonSolver, CholeskyNewtonSolver and QRCholeskyNewtonSolver
lorentzenchr May 9, 2022
ad65d89
Merge branch 'main' into glm_newton_cholesky
lorentzenchr May 9, 2022
267d570
ENH better singular hessian special solve
lorentzenchr May 11, 2022
dd5a820
CLN fix some typos found by reviewer
lorentzenchr May 11, 2022
bf1828d
TST assert ConvergenceWarning is raised
lorentzenchr May 11, 2022
9783e6b
MNT add BaseCholeskyNewtonSolver
lorentzenchr May 25, 2022
d373e63
WIP colinear design in GLMs
ogrisel May 25, 2022
bcc176e
Merge branch 'glm_newton_cholesky' into glm-newton-colinear
ogrisel May 30, 2022
c6efcef
FIX _solve_singular
lorentzenchr May 31, 2022
2381298
FIX solver for singular cases in GLMs (#5)
lorentzenchr May 31, 2022
d2063f7
FIX false unpacking in
lorentzenchr Jun 2, 2022
8d1137f
Merge branch 'main' into glm_newton_cholesky
lorentzenchr Jun 2, 2022
e6684c6
TST add tests for unpenalized GLMs
lorentzenchr Jun 3, 2022
3fb3695
TST fix solutions of glm_dataset
lorentzenchr Jun 10, 2022
2b6485e
ENH add SVDFallbackSolver
lorentzenchr Jun 10, 2022
59989f3
CLN remove SVDFallbackSolver
lorentzenchr Jun 11, 2022
d463817
ENH use gradient step for singular hessians
lorentzenchr Jun 11, 2022
8a108bb
ENH print iteration number in warnings
lorentzenchr Jun 11, 2022
82287af
TST improve test_linalg_warning_with_newton_solver
lorentzenchr Jun 11, 2022
9868a13
CLN LinAlgWarning fron scipy.linalg
lorentzenchr Jun 11, 2022
e3a2627
ENH more robust hessian
lorentzenchr Jun 12, 2022
0276cd9
ENH increase maxls for lbfgs to make it more robust
lorentzenchr Jun 13, 2022
b429472
ENH add hessian_warning for too many negative hessian values
lorentzenchr Jun 13, 2022
a85f251
CLN some warning messages
lorentzenchr Jun 13, 2022
c9b1200
ENH add lbfgs_step
lorentzenchr Jun 13, 2022
2f0ea15
ENH use lbfgs_step for hessian_warning
lorentzenchr Jun 13, 2022
9ce6cf2
TST make them pass
lorentzenchr Jun 13, 2022
221f611
TST tweek rtol for lbfgs
lorentzenchr Jun 13, 2022
aa81fb5
TST add rigoros test for GLMs
lorentzenchr Jun 13, 2022
cd06ba7
TST improve test_warm_start
lorentzenchr Jun 13, 2022
a27c7f9
ENH improve lbfgs options for better convergence
lorentzenchr Jun 14, 2022
a5c1fc0
CLN fix test_warm_start
lorentzenchr Jun 14, 2022
9b8519d
TST fix assert singular values in datasets
lorentzenchr Jun 15, 2022
68947b6
CLN address most review comments
lorentzenchr Jun 15, 2022
06a0b79
ENH enable more vebosity levels for lbfgs
lorentzenchr Jun 15, 2022
4d245cf
DOC add whatsnew
lorentzenchr Jun 15, 2022
f6bee64
Merge branch 'main' into glm_tests
lorentzenchr Jun 15, 2022
382d177
CLN remove xfail and clean a bit
lorentzenchr Jun 16, 2022
25fe6e1
CLN docstring about minimum norm
lorentzenchr Jun 16, 2022
4c4582d
More informative repr for the glm_dataset fixture cases
ogrisel Jun 16, 2022
2065a9e
Forgot to run black
ogrisel Jun 16, 2022
5aaaf21
CLN remove unnecessary filterwarnings
lorentzenchr Jun 17, 2022
10da880
CLN address review comments
lorentzenchr Jun 17, 2022
b98273b
Merge branch 'glm_tests' of https://github.com/lorentzenchr/scikit-le…
lorentzenchr Jun 17, 2022
e16d04e
Trigger [all random seeds] on the following tests:
ogrisel Jun 17, 2022
2fa4397
CLN add comment for lbfgs ftol=64 * machine precision
lorentzenchr Jun 17, 2022
bafc7e7
Merge branch 'glm_tests' of https://github.com/lorentzenchr/scikit-le…
lorentzenchr Jun 17, 2022
c0e2422
CLN XXX code comment
lorentzenchr Jun 17, 2022
1149342
Trigger [all random seeds] on the following tests:
lorentzenchr Jun 17, 2022
3dad445
CLN link issue and remove code snippet in comment
lorentzenchr Jun 17, 2022
12525f1
Trigger [all random seeds] on the following tests:
lorentzenchr Jun 17, 2022
556164a
CLN add catch_warnings
lorentzenchr Jun 17, 2022
4fcc1c8
Trigger [all random seeds] on the following tests:
lorentzenchr Jun 17, 2022
3569991
Merge branch 'main' into glm_tests
lorentzenchr Jun 17, 2022
c723f65
Trigger [all random seeds] on the following tests:
lorentzenchr Jun 17, 2022
3458c39
[all random seeds]
lorentzenchr Jun 18, 2022
99f4cf9
Trigger with -Werror [all random seeds]
lorentzenchr Jun 18, 2022
79ec862
ENH increase maxls to 50
lorentzenchr Jun 18, 2022
904e960
[all random seeds]
lorentzenchr Jun 18, 2022
4fd1d9b
Revert "Trigger with -Werror [all random seeds]"
lorentzenchr Jun 18, 2022
352b7c5
Merge branch 'glm_tests' into glm_newton_cholesky
lorentzenchr Jun 18, 2022
81efa1a
TST add catch_warnings to filterwarnings
lorentzenchr Jun 18, 2022
fa7469c
TST adapt tests for newton solvers
lorentzenchr Jun 18, 2022
ccb9866
CLN cleaner gradient step with gradient_times_newton
lorentzenchr Jun 19, 2022
28f2051
DOC add whatsnew
lorentzenchr Jun 19, 2022
2d9f205
ENH always use lbfgs as fallback
lorentzenchr Jun 19, 2022
e70a4df
TST adapt rtol
lorentzenchr Jun 19, 2022
85a1c52
TST fix test_linalg_warning_with_newton_solver
lorentzenchr Jun 20, 2022
0a557ca
CLN address some review comments
lorentzenchr Jun 20, 2022
be2fe6d
Improve tests related to convergence warning on collinear data
ogrisel Jun 29, 2022
0906f94
overfit -> fit
ogrisel Jun 30, 2022
0aa83ac
Typo in comment
ogrisel Jun 30, 2022
7ecfc45
Merge branch 'main' into glm_newton_cholesky
ogrisel Jun 30, 2022
325c849
Apply suggestions from code review
ogrisel Jun 30, 2022
eecd8e2
Merge remote-tracking branch 'origin/main' into glm_newton_cholesky
ogrisel Jul 1, 2022
d4206d6
ENH fallback_lbfgs_solve
lorentzenchr Jul 1, 2022
5e6aa99
ENH adapt rtol
lorentzenchr Jul 1, 2022
4992398
Merge branch 'glm_newton_cholesky' into test-glm-collinear-data
ogrisel Jul 1, 2022
15192f1
Improve test_linalg_warning_with_newton_solver
ogrisel Jul 1, 2022
621ffd8
Better comments
ogrisel Jul 1, 2022
83944aa
Merge pull request #8 from ogrisel/test-glm-collinear-data
ogrisel Jul 1, 2022
6413f07
Fixed Hessian casing and improved warning messages
ogrisel Jul 1, 2022
bfe3c38
[all random seeds]
ogrisel Jul 1, 2022
fa9e885
Ignore ConvergenceWarnings for now if convergence is good
ogrisel Jul 1, 2022
7318a4f
CLN remove counting of warnings
lorentzenchr Jul 2, 2022
34e297e
ENH fall back to lbfgs if line search did not converge
lorentzenchr Jul 2, 2022
d8c98a2
DOC better comment on performance bottleneck
lorentzenchr Jul 2, 2022
c0ec17d
Update GLM related examples to use the new solver
ogrisel Jul 5, 2022
bcf98af
Merge branch 'main' into glm_newton_cholesky
ogrisel Jul 5, 2022
a3b5f83
Merge branch 'main' into glm_newton_cholesky
jjerphan Aug 1, 2022
0d698d0
CLN address reviewer comments
lorentzenchr Sep 15, 2022
55cd86b
Merge branch 'main' into glm_newton_cholesky
lorentzenchr Oct 5, 2022
beeb774
EXA improve some wordings
lorentzenchr Oct 5, 2022
7c46dd8
CLN do not pop "solver in parameter constraints
lorentzenchr Oct 8, 2022
41e7c42
CLN fix typos
lorentzenchr Oct 9, 2022
9097536
DOC fix docstring
lorentzenchr Oct 9, 2022
a173124
CLN remove solver newton-qr-cholesky
lorentzenchr Oct 11, 2022
049a2fc
DOC update PR number in whatsnew
lorentzenchr Oct 11, 2022
f624d61
Merge branch 'main' into glm_newton_cholesky_only
lorentzenchr Oct 14, 2022
f225453
CLN address review comments
lorentzenchr Oct 14, 2022
28b3820
CLN remove unnecessary catch_warnings
lorentzenchr Oct 14, 2022
46841bd
CLN address some review comments
lorentzenchr Oct 18, 2022
02c4245
DOC more precise whatsnew
lorentzenchr Oct 23, 2022
f841e54
CLN use init_zero_coef
lorentzenchr Oct 23, 2022
e285f05
CLN use and test init_zero_coef
lorentzenchr Oct 23, 2022
55e57df
CLN address some review comments
lorentzenchr Oct 23, 2022
1d158cb
CLN mark NewtonSolver as private by leading underscore
lorentzenchr Oct 23, 2022
a30c71f
CLN exact comments for inner_solve
lorentzenchr Oct 23, 2022
298ce60
TST add test_newton_solver_verbosity
lorentzenchr Oct 23, 2022
00f7465
TST extend test_newton_solver_verbosity
lorentzenchr Oct 23, 2022
308fd88
TST logic in test_glm_regression_unpenalized
lorentzenchr Oct 24, 2022
ebf930b
TST use count_nonzero
lorentzenchr Oct 24, 2022
2ffa621
Merge branch 'main' into glm_newton_cholesky_only
lorentzenchr Oct 24, 2022
d304ce9
CLN remove super rare line search checks
lorentzenchr Oct 24, 2022
f002eb7
MNT move Newton solver to new file _newton_solver.py
lorentzenchr Oct 24, 2022
10 changes: 10 additions & 0 deletions doc/whats_new/v1.2.rst
@@ -353,6 +353,16 @@ Changelog
:mod:`sklearn.linear_model`
...........................

- |Enhancement| :class:`linear_model.GammaRegressor`,
  :class:`linear_model.PoissonRegressor` and :class:`linear_model.TweedieRegressor` got
  a new solver `solver="newton-cholesky"`. This is a 2nd order (Newton) optimization
  routine that uses a Cholesky decomposition of the Hessian matrix.
  When `n_samples >> n_features`, the `"newton-cholesky"` solver has been observed to
  converge both faster and to a higher precision solution than the `"lbfgs"` solver on
  problems with one-hot encoded categorical variables with some rare categorical
  levels.
  :pr:`24637` by :user:`Christian Lorentzen <lorentzenchr>`.

- |Enhancement| :class:`linear_model.GammaRegressor`,
  :class:`linear_model.PoissonRegressor` and :class:`linear_model.TweedieRegressor`
  can reach higher precision with the lbfgs solver, in particular when `tol` is set
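For context on the changelog entry above: a Newton solver updates the coefficients by solving the linear system H @ delta = -g, where g and H are the gradient and the Hessian of the penalized deviance at the current coefficients. Below is a minimal sketch of that step, assuming a symmetric positive definite Hessian. It is not scikit-learn's actual implementation, which (per the commits above) adds a line search and gradient/LBFGS fallbacks for singular Hessians, and lives in the new `_newton_solver.py`.

from scipy.linalg import cho_factor, cho_solve


def newton_cholesky_step(coef, gradient, hessian):
    """Hypothetical helper: one Newton step, solving H @ delta = -gradient.

    Assumes `hessian` is symmetric positive definite.
    """
    # Cholesky factorization costs O(n_features^3); the two triangular
    # solves afterwards are only O(n_features^2).
    c_and_lower = cho_factor(hessian)
    delta = cho_solve(c_and_lower, -gradient)
    return coef + delta

Compared with `"lbfgs"`, which approximates curvature from gradient differences, this uses the exact Hessian in every iteration, consistent with the changelog's observation of faster and more precise convergence when `n_samples >> n_features`.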
examples/linear_model/plot_poisson_regression_non_normal_loss.py
@@ -110,7 +110,11 @@
linear_model_preprocessor = ColumnTransformer(
[
("passthrough_numeric", "passthrough", ["BonusMalus"]),
("binned_numeric", KBinsDiscretizer(n_bins=10), ["VehAge", "DrivAge"]),
(
"binned_numeric",
KBinsDiscretizer(n_bins=10, subsample=int(2e5), random_state=0),
["VehAge", "DrivAge"],
),
("log_scaled_numeric", log_scale_transformer, ["Density"]),
(
"onehot_categorical",
@@ -247,7 +251,7 @@ def score_estimator(estimator, df_test):
poisson_glm = Pipeline(
[
("preprocessor", linear_model_preprocessor),
("regressor", PoissonRegressor(alpha=1e-12, max_iter=300)),
("regressor", PoissonRegressor(alpha=1e-12, solver="newton-cholesky")),
]
)
poisson_glm.fit(
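The diff above switches the example's Poisson estimator from `max_iter=300` with the default solver to the new one. As a quick, self-contained usage sketch (synthetic data; the shapes and `alpha` are illustrative, not taken from the example script):

import numpy as np
from sklearn.linear_model import PoissonRegressor

rng = np.random.RandomState(0)
X = rng.standard_normal(size=(1000, 5))
y = rng.poisson(lam=np.exp(X[:, 0]))  # counts with a log-linear mean

# Adopting the new solver is a one-argument change.
reg = PoissonRegressor(alpha=1e-4, solver="newton-cholesky").fit(X, y)
print(reg.score(X, y))  # D^2, the fraction of Poisson deviance explained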
92 changes: 76 additions & 16 deletions examples/linear_model/plot_tweedie_regression_insurance_claims.py
@@ -56,12 +56,12 @@
from sklearn.metrics import mean_squared_error


def load_mtpl2(n_samples=100000):
def load_mtpl2(n_samples=None):
"""Fetch the French Motor Third-Party Liability Claims dataset.

Parameters
----------
n_samples: int, default=100000
n_samples: int, default=None
number of samples to select (for faster run time). Full dataset has
678013 samples.
"""
@@ -215,7 +215,7 @@ def score_estimator(
from sklearn.compose import ColumnTransformer


df = load_mtpl2(n_samples=60000)
df = load_mtpl2()

# Note: filter out claims with zero amount, as the severity model
# requires strictly positive target values.
@@ -233,7 +233,11 @@ def score_estimator(

column_trans = ColumnTransformer(
[
("binned_numeric", KBinsDiscretizer(n_bins=10), ["VehAge", "DrivAge"]),
(
"binned_numeric",
KBinsDiscretizer(n_bins=10, subsample=int(2e5), random_state=0),
["VehAge", "DrivAge"],
),
(
"onehot_categorical",
OneHotEncoder(),
@@ -276,10 +280,26 @@ def score_estimator(

df_train, df_test, X_train, X_test = train_test_split(df, X, random_state=0)

# %%
#
# Let us keep in mind that despite the seemingly large number of data points in
# this dataset, the number of evaluation points where the claim amount is
# non-zero is quite small:
len(df_test)

# %%
len(df_test[df_test["ClaimAmount"] > 0])

# %%
#
# As a consequence, we expect significant variability in our evaluation upon
# random resampling of the train-test split.
#
# The parameters of the model are estimated by minimizing the Poisson deviance
# on the training set via a quasi-Newton solver: l-BFGS. Some of the features
# are collinear, we use a weak penalization to avoid numerical issues.
glm_freq = PoissonRegressor(alpha=1e-3, max_iter=400)
# on the training set via a Newton solver. Some of the features are collinear
# (e.g. because we did not drop any categorical level in the `OneHotEncoder`),
# so we use a weak L2 penalization to avoid numerical issues.
glm_freq = PoissonRegressor(alpha=1e-4, solver="newton-cholesky")
glm_freq.fit(X_train, df_train["Frequency"], sample_weight=df_train["Exposure"])

scores = score_estimator(
@@ -295,6 +315,12 @@ def score_estimator(
print(scores)

# %%
#
# Note that the score measured on the test set is surprisingly better than on
# the training set. This might be specific to this random train-test split.
# Proper cross-validation could help us to assess the sampling variability of
# these results.
#
# We can visually compare observed and predicted values, aggregated by the
# drivers age (``DrivAge``), vehicle age (``VehAge``) and the insurance
# bonus/malus (``BonusMalus``).
@@ -374,7 +400,7 @@ def score_estimator(
mask_train = df_train["ClaimAmount"] > 0
mask_test = df_test["ClaimAmount"] > 0

glm_sev = GammaRegressor(alpha=10.0, max_iter=10000)
glm_sev = GammaRegressor(alpha=10.0, solver="newton-cholesky")

glm_sev.fit(
X_train[mask_train.values],
@@ -395,13 +421,44 @@ def score_estimator(
print(scores)

# %%
# Here, the scores for the test data call for caution as they are
# significantly worse than for the training data, indicating overfitting
# despite the strong regularization.
#
# Note that the resulting model is the average claim amount per claim. As
# such, it is conditional on having at least one claim, and cannot be used to
# predict the average claim amount per policy in general.
# Those values of the metrics are not necessarily easy to interpret. It can be
# insightful to compare them with a model that does not use any input
# features and always predicts a constant value, i.e. the average claim
# amount, in the same setting:

from sklearn.dummy import DummyRegressor

dummy_sev = DummyRegressor(strategy="mean")
dummy_sev.fit(
X_train[mask_train.values],
df_train.loc[mask_train, "AvgClaimAmount"],
sample_weight=df_train.loc[mask_train, "ClaimNb"],
)

scores = score_estimator(
dummy_sev,
X_train[mask_train.values],
X_test[mask_test.values],
df_train[mask_train],
df_test[mask_test],
target="AvgClaimAmount",
weights="ClaimNb",
)
print("Evaluation of a mean predictor on target AvgClaimAmount")
print(scores)

# %%
#
# We conclude that the claim amount is very challenging to predict. Still, the
# :class:`~sklearn.linear_model.GammaRegressor` is able to leverage some
# information from the input features to slightly improve upon the mean
# baseline in terms of D².
#
# Note that the resulting model is the average claim amount per claim. As such,
# it is conditional on having at least one claim, and cannot be used to predict
# the average claim amount per policy. For this, it needs to be combined with
# a claims frequency model.

print(
"Mean AvgClaim Amount per policy: %.2f "
Expand All @@ -415,7 +472,10 @@ def score_estimator(
"Predicted Mean AvgClaim Amount | NbClaim > 0: %.2f"
% glm_sev.predict(X_train).mean()
)

print(
"Predicted Mean AvgClaim Amount (dummy) | NbClaim > 0: %.2f"
% dummy_sev.predict(X_train).mean()
)

# %%
# We can visually compare observed and predicted values, aggregated for
@@ -481,7 +541,7 @@ def score_estimator(
from sklearn.linear_model import TweedieRegressor


glm_pure_premium = TweedieRegressor(power=1.9, alpha=0.1, max_iter=10000)
glm_pure_premium = TweedieRegressor(power=1.9, alpha=0.1, solver="newton-cholesky")
glm_pure_premium.fit(
X_train, df_train["PurePremium"], sample_weight=df_train["Exposure"]
)
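The `DummyRegressor` comparison added in this diff is exactly what the D² score formalizes: D² = 1 - dev(y, ŷ) / dev(y, ȳ), the fraction of deviance explained relative to a constant mean prediction. A small self-contained check with made-up numbers (`power=2` corresponds to the Gamma deviance used by `GammaRegressor`):

import numpy as np
from sklearn.metrics import d2_tweedie_score, mean_tweedie_deviance

y_true = np.array([1.5, 2.0, 3.0, 4.5])
y_pred = np.array([1.0, 2.5, 3.0, 4.0])

# D^2 = 1 - dev(y_true, y_pred) / dev(y_true, mean(y_true))
dev_model = mean_tweedie_deviance(y_true, y_pred, power=2)
dev_mean = mean_tweedie_deviance(
    y_true, np.full_like(y_true, y_true.mean()), power=2
)
assert np.isclose(
    d2_tweedie_score(y_true, y_pred, power=2), 1 - dev_model / dev_mean
)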