ENH increase conlim to make more tests pass · scikit-learn/scikit-learn@83ce34f · GitHub

Commit 83ce34f

ENH increase conlim to make more tests pass

1 parent 39030c4 commit 83ce34f

File tree

2 files changed: +12 -7 lines changed

sklearn/linear_model/_glm/_newton_solver.py (7 additions & 2 deletions)

@@ -974,6 +974,9 @@ def inner_solve(self, X, y, sample_weight):
             atol=eta * norm_G / (self.A_norm * self.r_norm),
             btol=self.tol,
             maxiter=max(n_samples, n_features) * n_classes,  # default is min(A.shape)
+            # default conlim = 1e8, for compatible systems 1e12 is still reasonable,
+            # see LSMR documentation
+            conlim=1e12,
             show=self.verbose >= 3,
         )
         # We store the estimated Frobenius norm of A and norm of residual r in
@@ -988,6 +991,8 @@ def inner_solve(self, X, y, sample_weight):
             conda,
             normx,
         ) = result
+        if self.verbose >= 2:
+            print(f" Inner iterations in LSMR = {itn}")
         if self.coef.dtype == np.float32:
             self.coef_newton = self.coef_newton.astype(np.float32)
         if not self.linear_loss.base_loss.is_multiclass:
@@ -1010,7 +1015,7 @@ def inner_solve(self, X, y, sample_weight):
         if self.iteration == 1:
             return
         # Note: We could detect too large steps by comparing norm(coef_newton) = normx
-        # with norm(gradient) o with the already available condition number of A, e.g.
+        # with norm(gradient) or with the already available condition number of A, e.g.
         # conda.
         if istop == 7:
             self.use_fallback_lbfgs_solve = True
@@ -1033,7 +1038,7 @@ def inner_solve(self, X, y, sample_weight):
                 msg
                 + "It will now resort to lbfgs instead.\n"
                 "This may be caused by singular or very ill-conditioned Hessian "
-                " matrix. "
+                "matrix. "
                 "Further options are to use another solver or to avoid such situation "
                 "in the first place. Possible remedies are removing collinear features"
                 "of X or increasing the penalization strengths.",

sklearn/linear_model/_glm/tests/test_glm.py (5 additions & 5 deletions)

@@ -1093,8 +1093,8 @@ def test_solver_on_ill_conditioned_X(
         np.exp(X_orig @ np.ones(X_orig.shape[1])), size=X_orig.shape[0]
     ).astype(np.float64)
 
-    tol = 1e-7
-    model = PoissonRegressor(solver=solver, alpha=0.0, tol=tol)
+    tol = 1e-8
+    model = PoissonRegressor(solver=solver, alpha=0.0, tol=tol, max_iter=200)
 
     # No warning raised on well-conditioned design, even without regularization.
     with warnings.catch_warnings():
@@ -1122,8 +1122,8 @@ def test_solver_on_ill_conditioned_X(
 
     # Construct another ill-conditioned problem by scaling of features.
     X_ill_conditioned = X_orig.copy()
-    X_ill_conditioned[:, 0] *= 1e-4
-    X_ill_conditioned[:, 1] *= 1e4
+    X_ill_conditioned[:, 0] *= 1e-6
+    X_ill_conditioned[:, 1] *= 1e2  # too large X may overflow in link function
     # Make sure that it is ill conditioned <=> large condition number.
     assert np.linalg.cond(X_ill_conditioned) > 1e5 * np.linalg.cond(X_orig)
 
@@ -1144,7 +1144,7 @@ def test_solver_on_ill_conditioned_X(
     if test_loss:
         # Without penalty, scaling of columns has no effect on predictions.
         ill_cond_deviance = mean_poisson_deviance(y, reg.predict(X_ill_conditioned))
-        if solver in ("lbfgs", "newton-cholesky", "newton-lsmr"):
+        if solver in ("lbfgs", "newton-cholesky"):
             pytest.xfail(
                 f"Solver {solver} does not converge but does so without warning."
             )
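
As a standalone illustration of the two test changes above (synthetic data, not the actual test fixture): the column scaling drives up the condition number, and the upward scale was reduced from 1e4 to 1e2 because PoissonRegressor predicts through exp(X @ coef), so very large feature values can overflow float64 inside the link function.

import numpy as np

rng = np.random.default_rng(42)
X_orig = rng.standard_normal((50, 3))

X_ill_conditioned = X_orig.copy()
X_ill_conditioned[:, 0] *= 1e-6
X_ill_conditioned[:, 1] *= 1e2  # a factor of 1e4 risks exp() overflow

# Ill conditioned <=> large condition number.
assert np.linalg.cond(X_ill_conditioned) > 1e5 * np.linalg.cond(X_orig)

# float64 overflows in exp() once its argument exceeds ~709.78:
print(np.log(np.finfo(np.float64).max))  # about 709.78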
