Commit 7ad7090

TST Extend tests for scipy.sparse.*array in sklearn/linear_model/tests/test_linear_loss.py (#27133)
Signed-off-by: Nurseit Kamchyev <bncer.ml@gmail.com>
1 parent 2720ce7 commit 7ad7090
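
For context, `CSR_CONTAINERS` (imported from `sklearn.utils.fixes` in the diff below) abstracts over the CSR containers SciPy provides. A minimal sketch of what it amounts to, under the assumption that the upstream definition gates on SciPy version; the exact code in `sklearn/utils/fixes.py` may differ:

import scipy.sparse as sp

# csr_matrix always exists; csr_array was added in SciPy 1.8, so the list
# holds one or two container classes depending on the installed SciPy.
CSR_CONTAINERS = [sp.csr_matrix]
if hasattr(sp, "csr_array"):  # SciPy >= 1.8
    CSR_CONTAINERS.append(sp.csr_array)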

File tree

1 file changed: +9 −7 lines changed

sklearn/linear_model/tests/test_linear_loss.py

Lines changed: 9 additions & 7 deletions
@@ -7,7 +7,7 @@
 import numpy as np
 import pytest
 from numpy.testing import assert_allclose
-from scipy import linalg, optimize, sparse
+from scipy import linalg, optimize
 
 from sklearn._loss.loss import (
     HalfBinomialLoss,
@@ -17,6 +17,7 @@
 from sklearn.datasets import make_low_rank_matrix
 from sklearn.linear_model._linear_loss import LinearModelLoss
 from sklearn.utils.extmath import squared_norm
+from sklearn.utils.fixes import CSR_CONTAINERS
 
 # We do not need to test all losses, just what LinearModelLoss does on top of the
 # base losses.
@@ -104,8 +105,9 @@ def test_init_zero_coef(base_loss, fit_intercept, n_features, dtype):
 @pytest.mark.parametrize("fit_intercept", [False, True])
 @pytest.mark.parametrize("sample_weight", [None, "range"])
 @pytest.mark.parametrize("l2_reg_strength", [0, 1])
+@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
 def test_loss_grad_hess_are_the_same(
-    base_loss, fit_intercept, sample_weight, l2_reg_strength
+    base_loss, fit_intercept, sample_weight, l2_reg_strength, csr_container
 ):
     """Test that loss and gradient are the same across different functions."""
     loss = LinearModelLoss(base_loss=base_loss(), fit_intercept=fit_intercept)
@@ -150,7 +152,7 @@ def test_loss_grad_hess_are_the_same(
     assert_allclose(h4 @ g4, h3(g3))
 
     # same for sparse X
-    X = sparse.csr_matrix(X)
+    X = csr_container(X)
     l1_sp = loss.loss(
         coef, X, y, sample_weight=sample_weight, l2_reg_strength=l2_reg_strength
     )
@@ -182,9 +184,9 @@ def test_loss_grad_hess_are_the_same(
 @pytest.mark.parametrize("base_loss", LOSSES)
 @pytest.mark.parametrize("sample_weight", [None, "range"])
 @pytest.mark.parametrize("l2_reg_strength", [0, 1])
-@pytest.mark.parametrize("X_sparse", [False, True])
+@pytest.mark.parametrize("X_container", CSR_CONTAINERS + [None])
 def test_loss_gradients_hessp_intercept(
-    base_loss, sample_weight, l2_reg_strength, X_sparse
+    base_loss, sample_weight, l2_reg_strength, X_container
 ):
     """Test that loss and gradient handle intercept correctly."""
     loss = LinearModelLoss(base_loss=base_loss(), fit_intercept=False)
@@ -199,8 +201,8 @@ def test_loss_gradients_hessp_intercept(
         :, :-1
     ]  # exclude intercept column as it is added automatically by loss_inter
 
-    if X_sparse:
-        X = sparse.csr_matrix(X)
+    if X_container is not None:
+        X = X_container(X)
 
     if sample_weight == "range":
         sample_weight = np.linspace(1, y.shape[0], num=y.shape[0])
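
The pattern above generalizes: a test that previously hard-coded `sparse.csr_matrix` now runs once per available CSR container, and `CSR_CONTAINERS + [None]` additionally covers the dense case. A self-contained sketch of the same parametrization pattern; the test name and data here are hypothetical, not part of this commit:

import numpy as np
import pytest

from sklearn.utils.fixes import CSR_CONTAINERS


# None stands for "leave X dense"; each CSR container converts X to sparse.
@pytest.mark.parametrize("X_container", CSR_CONTAINERS + [None])
def test_sum_matches_dense(X_container):
    X = np.arange(6.0).reshape(2, 3)
    if X_container is not None:
        X = X_container(X)  # csr_matrix or csr_array, depending on SciPy
    assert X.sum() == 15.0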
