8000 CLN fix dtype and tests · scikit-learn/scikit-learn@1d13089 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1d13089

Browse files
committed
CLN fix dtype and tests
- keep dtype float32 after LSMR - lower test precision in test_NewtonLSMRSolver_multinomial_A_b_on_3_classes
1 parent 86da909 commit 1d13089

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

sklearn/linear_model/_glm/_newton_solver.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -918,6 +918,8 @@ def inner_solve(self, X, y, sample_weight):
918918
conda,
919919
normx,
920920
) = result
921+
if self.coef.dtype == np.float32:
922+
self.coef_newton = self.coef_newton.astype(np.float32)
921923
if not self.linear_loss.base_loss.is_multiclass:
922924
self.gradient_times_newton = self.gradient @ self.coef_newton
923925
else:

sklearn/linear_model/_glm/tests/test_glm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1512,7 +1512,7 @@ def test_NewtonLSMRSolver_multinomial_A_b_on_3_classes(
15121512
(n_classes, -1)
15131513
)
15141514
# sum over classes = 0
1515-
assert_allclose(res1.sum(axis=0), 0, atol=5e-10)
1515+
assert_allclose(res1.sum(axis=0), 0, atol=5e-9)
15161516
assert_allclose((H_ext + pen_ext) @ res1.ravel(), -(G_ext + pen_ext @ coef_ext))
15171517
assert_allclose(A.T @ A @ res1.ravel(order="C"), A.T @ b)
15181518
res2 = lsmr(A, b, maxiter=(n_features + n_samples) * n_classes, show=True)

0 commit comments

Comments
 (0)
0